PPO

The Proximal Policy Optimization algorithm combines ideas from A2C (having multiple workers) and TRPO (it uses a trust region to improve the actor).

The main idea is that after an update, the new policy should not be too far from the old policy. To achieve this, PPO uses clipping to avoid overly large updates.
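The clipping idea can be illustrated with a minimal sketch of the clipped surrogate objective from the PPO paper, for a single sample. This is plain Python rather than the library's PyTorch code. The function name `clipped_surrogate` and the sample numbers are illustrative assumptions; `clip_range` plays the role of the constructor parameter of the same name.

```python
import math


def clipped_surrogate(log_prob_new, log_prob_old, advantage, clip_range=0.2):
    """Clipped surrogate objective for one sample (to be maximized)."""
    # Probability ratio between the new and the old policy.
    ratio = math.exp(log_prob_new - log_prob_old)
    # Ratio restricted to [1 - clip_range, 1 + clip_range].
    clipped_ratio = max(1.0 - clip_range, min(ratio, 1.0 + clip_range))
    # Taking the minimum removes the incentive to push the ratio
    # outside the clipping interval.
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage, ratio 1.5: the objective is capped at 1.2 * advantage.
capped = clipped_surrogate(math.log(1.5), 0.0, advantage=2.0)
# Negative advantage, ratio 0.5: the objective is bounded by 0.8 * advantage.
floored = clipped_surrogate(math.log(0.5), 0.0, advantage=-2.0)
# Ratio inside the interval: the objective is left unchanged.
unclipped = clipped_surrogate(math.log(1.1), 0.0, advantage=2.0)
```

In all three cases the result is the more pessimistic of the clipped and unclipped terms, so moving the policy further from the old one brings no extra gain.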

Note

PPO contains several modifications from the original algorithm not documented by OpenAI: advantages are normalized and the value function can also be clipped.
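A rough sketch of these two modifications, in plain Python, may help make them concrete. The helper names `normalize_advantages` and `clip_value_prediction`, the `eps` constant and the sample numbers are assumptions made for illustration; this is not the library's code.

```python
import statistics


def normalize_advantages(advantages, eps=1e-8):
    """Rescale advantages to zero mean and unit standard deviation."""
    mean = statistics.fmean(advantages)
    # Sample standard deviation: it is undefined for a single value,
    # so at least two advantages are required.
    std = statistics.stdev(advantages)
    return [(a - mean) / (std + eps) for a in advantages]


def clip_value_prediction(new_value, old_value, clip_range_vf):
    """Keep the new value prediction within clip_range_vf of the old one."""
    delta = new_value - old_value
    return old_value + max(-clip_range_vf, min(delta, clip_range_vf))


normalized = normalize_advantages([1.0, 2.0, 3.0])
clipped_value = clip_value_prediction(new_value=5.0, old_value=3.0, clip_range_vf=0.5)
```

Because the value clipping works on raw value differences, a suitable `clip_range_vf` depends on the scale of the rewards.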

Notes

Can I use?

Note

A recurrent version of PPO is available in our contrib repo: https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html

However, we advise users to start with frame-stacking as a simpler, faster and usually competitive alternative. More info is available in our report: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4 (see also the Procgen paper, appendix Fig. 11). In practice, you can stack multiple observations using VecFrameStack.
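To show what frame-stacking does conceptually, here is a small stand-alone sketch in plain Python. The `FrameStack` class is a hypothetical helper written for this illustration, not the VecFrameStack implementation; in particular, filling the stack with copies of the first observation at reset is an assumption of the sketch.

```python
from collections import deque


class FrameStack:
    """Keep the last n_stack observations and return them concatenated."""

    def __init__(self, n_stack):
        self.frames = deque(maxlen=n_stack)

    def reset(self, observation):
        # Fill the whole stack with the first observation of the episode.
        for _ in range(self.frames.maxlen):
            self.frames.append(list(observation))
        return self.stacked()

    def step(self, observation):
        # The deque drops the oldest observation automatically.
        self.frames.append(list(observation))
        return self.stacked()

    def stacked(self):
        # Oldest observation first, newest last.
        return [x for frame in self.frames for x in frame]


stack = FrameStack(n_stack=3)
first = stack.reset([0.0, 1.0])   # position-like readings, no velocity
second = stack.step([0.5, 1.5])
```

With several consecutive observations in its input, a feed-forward policy can infer quantities such as velocity from their differences, which is why stacking can stand in for a recurrent policy.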

  • Recurrent policies: ❌

  • Multi processing: ✔️

  • Gym spaces:

Space           Action    Observation
Discrete        ✔️        ✔️
Box             ✔️        ✔️
MultiDiscrete   ✔️        ✔️
MultiBinary     ✔️        ✔️
Dict            ❌        ✔️

Example

This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo repository.

Train a PPO agent on CartPole-v1 using 4 environments.

import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Parallel environments
vec_env = make_vec_env("CartPole-v1", n_envs=4)

model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo_cartpole")

del model # remove to demonstrate saving and loading

model = PPO.load("ppo_cartpole")

obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")

Results

Atari Games

The complete learning curves are available in the associated PR #110.

PyBullet Environments

Results on the PyBullet benchmark (2M steps) using 6 seeds. The complete learning curves are available in the associated issue #48.

Note

Hyperparameters from the gSDE paper were used (as they are tuned for PyBullet envs).

Gaussian means that unstructured Gaussian noise is used for exploration; gSDE (generalized State-Dependent Exploration) is used otherwise.

Environments   A2C Gaussian    A2C gSDE        PPO Gaussian    PPO gSDE
HalfCheetah    2003 +/- 54     2032 +/- 122    1976 +/- 479    2826 +/- 45
Ant            2286 +/- 72     2443 +/- 89     2364 +/- 120    2782 +/- 76
Hopper         1627 +/- 158    1561 +/- 220    1567 +/- 339    2512 +/- 21
Walker2D       577 +/- 65      839 +/- 56      1230 +/- 147    2019 +/- 64

How to replicate the results?

Clone the rl-zoo repo:

git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/

Run the benchmark (replace $ENV_ID by the envs mentioned above):

python train.py --algo ppo --env $ENV_ID --eval-episodes 10 --eval-freq 10000

Plot the results (here for PyBullet envs only):

python scripts/all_plots.py -a ppo -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ppo_results
python scripts/plot_from_file.py -i logs/ppo_results.pkl -latex -l PPO

Parameters

class stable_baselines3.ppo.PPO(policy, env, learning_rate=0.0003, n_steps=2048, batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None, normalize_advantage=True, ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, use_sde=False, sde_sample_freq=-1, rollout_buffer_class=None, rollout_buffer_kwargs=None, target_kl=None, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)

Proximal Policy Optimization algorithm (PPO) (clip version)

Paper: https://arxiv.org/abs/1707.06347

Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/), https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)

Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

Parameters:
  • policy (ActorCriticPolicy) – The policy model to use (MlpPolicy, CnnPolicy, …)

  • env (Env | VecEnv | str) – The environment to learn from (if registered in Gym, can be str)

  • learning_rate (float | Callable[[float], float]) – The learning rate, it can be a function of the current progress remaining (from 1 to 0)

  • n_steps (int) – The number of steps to run for each environment per update (i.e. the rollout buffer size is n_steps * n_envs, where n_envs is the number of environment copies running in parallel). NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization). See https://github.com/pytorch/pytorch/issues/29372

  • batch_size (int) – Minibatch size

  • n_epochs (int) – Number of epochs when optimizing the surrogate loss

  • gamma (float) – Discount factor

  • gae_lambda (float) – Factor for trade-off of bias vs variance for Generalized Advantage Estimator

  • clip_range (float | Callable[[float], float]) – Clipping parameter, it can be a function of the current progress remaining (from 1 to 0).

  • clip_range_vf (None | float | Callable[[float], float]) – Clipping parameter for the value function, it can be a function of the current progress remaining (from 1 to 0). This is a parameter specific to the OpenAI implementation. If None is passed (default), no clipping will be done on the value function. IMPORTANT: this clipping depends on the reward scaling.

  • normalize_advantage (bool) – Whether to normalize the advantage

  • ent_coef (float) – Entropy coefficient for the loss calculation

  • vf_coef (float) – Value function coefficient for the loss calculation

  • max_grad_norm (float) – The maximum value for the gradient clipping

  • use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)

  • sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)

  • rollout_buffer_class (Type[RolloutBuffer] | None) – Rollout buffer class to use. If None, it will be automatically selected.

  • rollout_buffer_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the rollout buffer on creation

  • target_kl (float | None) – Limit the KL divergence between updates, because the clipping is not enough to prevent large updates; see issue #213 (cf. https://github.com/hill-a/stable-baselines/issues/213). By default, there is no limit on the KL divergence.

  • stats_window_size (int) – Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over

  • tensorboard_log (str | None) – the log location for tensorboard (if None, no logging)

  • policy_kwargs (Dict[str, Any] | None) – additional arguments to be passed to the policy on creation

  • verbose (int) – Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages

  • seed (int | None) – Seed for the pseudo random generators

  • device (device | str) – Device (cpu, cuda, …) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible.

  • _init_setup_model (bool) – Whether or not to build the network at the creation of the instance
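Several parameters above (learning_rate, clip_range, clip_range_vf) accept either a float or a function of the current progress remaining, which goes from 1 to 0. A minimal sketch of such a function is shown here; the name `linear_schedule` is an illustrative assumption, not something defined on this page.

```python
from typing import Callable


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Return a schedule that decays linearly from initial_value to 0."""

    def schedule(progress_remaining: float) -> float:
        # progress_remaining goes from 1 (start of training) to 0 (end).
        return progress_remaining * initial_value

    return schedule


lr_schedule = linear_schedule(0.0003)
start, halfway, end = lr_schedule(1.0), lr_schedule(0.5), lr_schedule(0.0)
```

A callable of this shape matches the `Callable[[float], float]` type listed for those parameters, so it could be passed in place of a constant value.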

collect_rollouts(env, callback, rollout_buffer, n_rollout_steps)

Collect experiences using the current policy and fill a RolloutBuffer. The term rollout here refers to the model-free notion and should not be confused with the concept of rollout used in model-based RL or planning.

Parameters:
  • env (VecEnv) – The training environment

  • callback (BaseCallback) – Callback that will be called at each step (and at the beginning and end of the rollout)

  • rollout_buffer (RolloutBuffer) – Buffer to fill with rollouts

  • n_rollout_steps (int) – Number of experiences to collect per environment

Returns:

True if function returned with at least n_rollout_steps collected, False if callback terminated rollout prematurely.

Return type:

bool
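Once experiences are collected, the gamma and gae_lambda constructor parameters determine how advantages are estimated from them. The following is a minimal sketch of Generalized Advantage Estimation for a single environment, written in plain Python from the usual GAE recursion. The function name `compute_gae`, its argument layout and the toy numbers are assumptions made for illustration; this is not the RolloutBuffer code.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Return (advantages, returns) for one trajectory segment.

    dones[t] is True when the episode ended after step t.
    """
    advantages = [0.0] * len(rewards)
    next_value = last_value
    running_advantage = 0.0
    for t in reversed(range(len(rewards))):
        # No bootstrapping across an episode boundary.
        not_done = 0.0 if dones[t] else 1.0
        # One-step temporal-difference error.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        running_advantage = delta + gamma * gae_lambda * not_done * running_advantage
        advantages[t] = running_advantage
        next_value = values[t]
    returns = [adv + value for adv, value in zip(advantages, values)]
    return advantages, returns


advantages, returns = compute_gae(
    rewards=[1.0, 1.0],
    values=[0.5, 0.5],
    dones=[False, True],
    last_value=0.0,
    gamma=1.0,
    gae_lambda=1.0,
)
```

With gamma and gae_lambda both set to 1, the returns reduce to plain Monte Carlo returns. Lowering gae_lambda relies more on the value estimates, which trades variance for bias.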

get_env()

Returns the current environment (can be None if not defined).

Returns:

The current environment

Return type:

VecEnv | None

get_parameters()

Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions).

Returns:

Mapping from names of the objects to PyTorch state-dicts.

Return type:

Dict[str, Dict]

get_vec_normalize_env()

Return the VecNormalize wrapper of the training env if it exists.

Returns:

The VecNormalize env.

Return type:

VecNormalize | None

learn(total_timesteps, callback=None, log_interval=1, tb_log_name='PPO', reset_num_timesteps=True, progress_bar=False)

Return a trained model.

Parameters:
  • total_timesteps (int) – The total number of samples (env steps) to train on

  • callback (None | Callable | List[BaseCallback] | BaseCallback) – callback(s) called at every step with state of the algorithm.

  • log_interval (int) – for on-policy algos (e.g., PPO, A2C, …) this is the number of training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging; for off-policy algos (e.g., TD3, SAC, …) this is the number of episodes before logging.

  • tb_log_name (str) – the name of the run for TensorBoard logging

  • reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)

  • progress_bar (bool) – Display a progress bar using tqdm and rich.

Returns:

the trained model

Return type:

SelfPPO
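The relation between total_timesteps, n_steps, the number of environments and log_interval is simple arithmetic. The sketch below works through it for the defaults in the constructor signature and the 4-environment CartPole example. It assumes that training collects whole rollouts until total_timesteps is reached; that is an assumption of this sketch rather than a statement taken from this page.

```python
import math

n_steps = 2048        # default from the constructor signature
n_envs = 4            # as in the CartPole example
batch_size = 64       # default minibatch size
n_epochs = 10         # default number of epochs
total_timesteps = 25000
log_interval = 1

# Transitions gathered before each policy update.
rollout_size = n_steps * n_envs
# Minibatches per pass over the rollout buffer.
minibatches_per_epoch = rollout_size // batch_size
# Gradient steps per policy update.
gradient_steps_per_update = minibatches_per_epoch * n_epochs
# Environment steps between two log outputs.
timesteps_between_logs = log_interval * rollout_size
# Assumption: whole rollouts are collected until total_timesteps is reached.
n_updates = math.ceil(total_timesteps / rollout_size)
```

Under that assumption, the run would collect 4 rollouts of 8192 transitions each, which is somewhat more than the requested 25000 timesteps.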

classmethod load(path, env=None, device='auto', custom_objects=None, print_system_info=False, force_reset=True, **kwargs)

Load the model from a zip-file. Warning: load re-creates the model from scratch, it does not update it in-place! For an in-place load use set_parameters instead.

Parameters:
  • path (str | Path | BufferedIOBase) – path to the file (or a file-like) where to load the agent from

  • env (Env | VecEnv | None) – the new environment to run the loaded model on (can be None if you only need prediction from a trained model); it has priority over any saved environment

  • device (device | str) – Device on which the code should run.

  • custom_objects (Dict[str, Any] | None) – Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in keras.models.load_model. Useful when you have an object in the file that cannot be deserialized.

  • print_system_info (bool) – Whether to print system info from the saved model and the current system info (useful to debug loading issues)

  • force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597

  • kwargs – extra arguments to change the model when loading

Returns:

new model instance with loaded parameters

Return type:

SelfBaseAlgorithm

property logger: Logger

Getter for the logger object.

predict(observation, state=None, episode_start=None, deterministic=False)

Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images).

Parameters:
  • observation (ndarray | Dict[str, ndarray]) – the input observation

  • state (Tuple[ndarray, ...] | None) – The last hidden states (can be None, used in recurrent policies)

  • episode_start (ndarray | None) – The last masks (can be None, used in recurrent policies). These correspond to the beginning of episodes, where the hidden states of the RNN must be reset.

  • deterministic (bool) – Whether or not to return deterministic actions.

Returns:

the model’s action and the next hidden state (used in recurrent policies)

Return type:

Tuple[ndarray, Tuple[ndarray, …] | None]

save(path, exclude=None, include=None)

Save all the attributes of the object and the model parameters in a zip-file.

Parameters:
  • path (str | Path | BufferedIOBase) – path to the file where the rl agent should be saved

  • exclude (Iterable[str] | None) – name of parameters that should be excluded in addition to the default ones

  • include (Iterable[str] | None) – name of parameters that might be excluded but should be included anyway

Return type:

None

set_env(env, force_reset=True)

Checks the validity of the environment and, if it is coherent, sets it as the current environment. Furthermore, wraps any non-vectorized env into a vectorized one. Checked parameters: observation_space and action_space.

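Parameters:
  • env (Env | VecEnv) – The environment to set as the current environment

  • force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597
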
Return type:

None

set_logger(logger)

Setter for the logger object.

Warning

When passing a custom logger object, this will overwrite tensorboard_log and verbose settings passed to the constructor.

Parameters:

logger (Logger) –

Return type:

None

set_parameters(load_path_or_dict, exact_match=True, device='auto')

Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see get_parameters).

Parameters:
  • load_path_or_dict (str | Dict[str, Tensor]) – Location of the saved data (path or file-like, see save), or a nested dictionary containing nn.Module parameters used by the policy. The dictionary maps object names to a state-dictionary returned by torch.nn.Module.state_dict().

  • exact_match (bool) – If True, the given parameters should include parameters for each module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters.

  • device (device | str) – Device on which the code should run.

Return type:

None

set_random_seed(seed=None)

Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)

Parameters:

seed (int | None) –

Return type:

None

train()

Update policy using the currently gathered rollout buffer.

Return type:

None
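The target_kl constructor parameter limits the KL divergence between updates. One way such a limit can work is to estimate the KL divergence from the log-probabilities of the sampled actions and stop the current update once the estimate becomes too large. The sketch below illustrates that idea in plain Python. The estimator, the `margin` factor of 1.5 and the helper names are assumptions of this sketch, not details stated on this page.

```python
import math


def approx_kl(log_probs_new, log_probs_old):
    """Sample-based KL estimate: mean of (r - 1) - log(r), with r = new / old."""
    total = 0.0
    for new, old in zip(log_probs_new, log_probs_old):
        log_ratio = new - old
        total += (math.exp(log_ratio) - 1.0) - log_ratio
    return total / len(log_probs_new)


def should_stop_early(log_probs_new, log_probs_old, target_kl, margin=1.5):
    """Return True when the estimated KL exceeds margin * target_kl."""
    if target_kl is None:
        return False  # no limit on the KL divergence
    return approx_kl(log_probs_new, log_probs_old) > margin * target_kl


old = [math.log(0.5), math.log(0.5)]
unchanged = approx_kl(old, old)
shifted = [math.log(0.9), math.log(0.1)]
stop = should_stop_early(shifted, old, target_kl=0.01)
no_limit = should_stop_early(shifted, old, target_kl=None)
```

Each term of the estimate is non-negative and equals zero only when the two log-probabilities match, so the value grows as the new policy drifts away from the old one.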

PPO Policies

stable_baselines3.ppo.MlpPolicy

alias of ActorCriticPolicy

class stable_baselines3.common.policies.ActorCriticPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)

Policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the like.

Parameters:
  • observation_space (Space) – Observation space

  • action_space (Space) – Action space

  • lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)

  • net_arch (List[int] | Dict[str, List[int]] | None) – The specification of the policy and value networks.

  • activation_fn (Type[Module]) – Activation function

  • ortho_init (bool) – Whether to use orthogonal initialization

  • use_sde (bool) – Whether to use State Dependent Exploration or not

  • log_std_init (float) – Initial value for the log standard deviation

  • full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE

  • use_expln (bool) – Use the expln() function instead of exp() to ensure a positive standard deviation (cf. paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.

  • squash_output (bool) – Whether to squash the output using a tanh function; this ensures that the bounds are respected when using gSDE.

  • features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.

  • features_extractor_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the features extractor.

  • share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.

  • normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)

  • optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default

  • optimizer_kwargs (Dict[str, Any] | None) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
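The use_expln option above replaces exp() with expln() when turning the log standard deviation into a standard deviation. A minimal sketch of such a function follows, assuming the piecewise form exp(x) for x <= 0 and log(1 + x) + 1 for x > 0. That definition is an assumption of this sketch and should be checked against the paper referenced above.

```python
import math


def expln(x):
    """exp(x) for x <= 0, log(1 + x) + 1 for x > 0."""
    if x <= 0.0:
        return math.exp(x)
    return math.log1p(x) + 1.0


# Continuous at 0, always positive, and much slower than exp() for large x.
at_zero = expln(0.0)
negative = expln(-5.0)
large = expln(10.0)
```

For positive inputs the growth is logarithmic instead of exponential, which is what keeps the standard deviation from growing too fast.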

evaluate_actions(obs, actions)

Evaluate actions according to the current policy, given the observations.

Parameters:
  • obs (Tensor | Dict[str, Tensor]) – Observation

  • actions (Tensor) – Actions

Returns:

estimated value, log likelihood of taking those actions and entropy of the action distribution.

Return type:

Tuple[Tensor, Tensor, Tensor | None]
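To make the last two returned quantities concrete, the sketch below computes the log likelihood of an action and the entropy of the action distribution for a discrete (softmax) policy, in plain Python. The function name `categorical_log_prob_and_entropy` and the logits are illustrative assumptions; the library works on batched tensors and also supports other distribution types.

```python
import math


def categorical_log_prob_and_entropy(logits, action):
    """Log likelihood of `action` and entropy of a softmax distribution."""
    # Numerically stable log-softmax.
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(v - max_logit) for v in logits))
    log_probs = [v - log_norm for v in logits]
    # Entropy: -sum(p * log p) over all actions.
    entropy = -sum(math.exp(lp) * lp for lp in log_probs)
    return log_probs[action], entropy


# Two equally likely actions: log likelihood is -log(2), entropy is log(2).
log_prob, entropy = categorical_log_prob_and_entropy([0.0, 0.0], action=1)
```

The log likelihood is what enters the probability ratio of the surrogate loss, while the entropy is the quantity weighted by ent_coef.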

extract_features(obs, features_extractor=None)

Preprocess the observation if needed and extract features.

Parameters:
  • obs (Tensor | Dict[str, Tensor]) – Observation

  • features_extractor (BaseFeaturesExtractor | None) – The features extractor to use. If None, then self.features_extractor is used.

Returns:

The extracted features. If features extractor is not shared, returns a tuple with the features for the actor and the features for the critic.

Return type:

Tensor | Tuple[Tensor, Tensor]

forward(obs, deterministic=False)

Forward pass in all the networks (actor and critic)

Parameters:
  • obs (Tensor) – Observation

  • deterministic (bool) – Whether to sample or use deterministic actions

Returns:

action, value and log probability of the action

Return type:

Tuple[Tensor, Tensor, Tensor]

get_distribution(obs)

Get the current policy distribution given the observations.

Parameters:

obs (Tensor | Dict[str, Tensor]) –

Returns:

the action distribution.

Return type:

Distribution

predict_values(obs)

Get the estimated values according to the current policy given the observations.

Parameters:

obs (Tensor | Dict[str, Tensor]) – Observation

Returns:

the estimated values.

Return type:

Tensor

reset_noise(n_envs=1)

Sample new weights for the exploration matrix.

Parameters:

n_envs (int) –

Return type:

None

stable_baselines3.ppo.CnnPolicy

alias of ActorCriticCnnPolicy

class stable_baselines3.common.policies.ActorCriticCnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)

CNN policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the like.

Parameters:
  • observation_space (Space) – Observation space

  • action_space (Space) – Action space

  • lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)

  • net_arch (List[int] | Dict[str, List[int]] | None) – The specification of the policy and value networks.

  • activation_fn (Type[Module]) – Activation function

  • ortho_init (bool) – Whether to use orthogonal initialization

  • use_sde (bool) – Whether to use State Dependent Exploration or not

  • log_std_init (float) – Initial value for the log standard deviation

  • full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE

  • use_expln (bool) – Use the expln() function instead of exp() to ensure a positive standard deviation (cf. paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.

  • squash_output (bool) – Whether to squash the output using a tanh function; this ensures that the bounds are respected when using gSDE.

  • features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.

  • features_extractor_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the features extractor.

  • share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.

  • normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)

  • optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default

  • optimizer_kwargs (Dict[str, Any] | None) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer

stable_baselines3.ppo.MultiInputPolicy

alias of MultiInputActorCriticPolicy

class stable_baselines3.common.policies.MultiInputActorCriticPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.CombinedExtractor'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)

Multi-input policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the like.

Parameters:
  • observation_space (Dict) – Observation space (Dict)

  • action_space (Space) – Action space

  • lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)

  • net_arch (List[int] | Dict[str, List[int]] | None) – The specification of the policy and value networks.

  • activation_fn (Type[Module]) – Activation function

  • ortho_init (bool) – Whether to use orthogonal initialization

  • use_sde (bool) – Whether to use State Dependent Exploration or not

  • log_std_init (float) – Initial value for the log standard deviation

  • full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE

  • use_expln (bool) – Use the expln() function instead of exp() to ensure a positive standard deviation (cf. paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.

  • squash_output (bool) – Whether to squash the output using a tanh function; this ensures that the bounds are respected when using gSDE.

  • features_extractor_class (Type[BaseFeaturesExtractor]) – Uses the CombinedExtractor

  • features_extractor_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the features extractor.

  • share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.

  • normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)

  • optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default

  • optimizer_kwargs (Dict[str, Any] | None) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer