A2C
A synchronous, deterministic variant of Asynchronous Advantage Actor Critic (A3C). It uses multiple workers to avoid the use of a replay buffer.
Warning
If you find training unstable or want to match the performance of stable-baselines A2C, consider using the RMSpropTFLike optimizer from stable_baselines3.common.sb2_compat.rmsprop_tf_like. You can change the optimizer with A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5))).
Read more here.
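To see why the optimizer choice can matter, here is a minimal pure-Python sketch of one RMSprop step on a single scalar parameter, written two ways: with epsilon added outside the square root (as PyTorch's torch.optim.RMSprop does) and with epsilon inside it, with the running average of squared gradients started at one instead of zero. Those two differences are our understanding of what a TensorFlow-like RMSprop changes; the helper `rmsprop_step` is hypothetical and ignores momentum, centering and weight decay. The values lr=7e-4 and eps=1e-5 mirror A2C's learning_rate and rms_prop_eps defaults, while alpha=0.99 is an assumption.

```python
import math


def rmsprop_step(param, grad, sq_avg, lr=7e-4, alpha=0.99, eps=1e-5, tf_like=False):
    """Hypothetical single RMSprop step on a scalar (no momentum, no weight decay).

    Returns the updated parameter and the updated running average of squared gradients.
    """
    # Exponential moving average of the squared gradient
    sq_avg = alpha * sq_avg + (1.0 - alpha) * grad * grad
    if tf_like:
        denom = math.sqrt(sq_avg + eps)  # epsilon inside the square root
    else:
        denom = math.sqrt(sq_avg) + eps  # epsilon outside the square root (PyTorch style)
    return param - lr * grad / denom, sq_avg


# First step from the same starting parameter with a gradient of 1.0
torch_param, _ = rmsprop_step(0.0, 1.0, sq_avg=0.0)               # running average starts at 0
tf_param, _ = rmsprop_step(0.0, 1.0, sq_avg=1.0, tf_like=True)    # running average starts at 1
print(f"PyTorch-style first step: {abs(torch_param):.6f}")
print(f"TF-like first step:       {abs(tf_param):.6f}")
```

In this toy case the PyTorch-style first step is roughly ten times larger than the TF-like one, which is one plausible source of the early instability the warning refers to.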
Notes
Original paper: https://arxiv.org/abs/1602.01783
OpenAI blog post: https://openai.com/blog/baselines-acktr-a2c/
Can I use?
Recurrent policies: ❌
Multi processing: ✔️
Gym spaces:
Space | Action | Observation
---|---|---
Discrete | ✔️ | ✔️
Box | ✔️ | ✔️
MultiDiscrete | ✔️ | ✔️
MultiBinary | ✔️ | ✔️
Dict | ❌ | ✔️
Example
This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo repository.
Train an A2C agent on CartPole-v1 using 4 environments.
```python
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env

# Parallel environments
vec_env = make_vec_env("CartPole-v1", n_envs=4)

model = A2C("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=25000)
model.save("a2c_cartpole")

del model # remove to demonstrate saving and loading

model = A2C.load("a2c_cartpole")

obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
```
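A quick back-of-the-envelope check of what the CartPole example does, assuming the default n_steps=5: each update uses n_steps * n_envs transitions, so learn(total_timesteps=25000) performs about 25000 / 20 updates. The sketch also shows a linear learning-rate schedule of the kind the learning_rate parameter accepts (a callable of the remaining progress, which goes from 1 to 0); `linear_schedule` is a hypothetical helper, not part of the library.

```python
import math
from typing import Callable

n_envs = 4          # from make_vec_env(..., n_envs=4)
n_steps = 5         # A2C default
total_timesteps = 25_000

batch_size = n_steps * n_envs                          # transitions per gradient update
n_updates = math.ceil(total_timesteps / batch_size)    # rollouts collected during learn()
print(batch_size, n_updates)


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Hypothetical helper: decay the learning rate linearly to 0."""
    def schedule(progress_remaining: float) -> float:
        # progress_remaining goes from 1 (start of training) to 0 (end)
        return progress_remaining * initial_value
    return schedule


lr = linear_schedule(7e-4)
print(lr(1.0), lr(0.5), lr(0.0))
# Usage (requires stable-baselines3):
# model = A2C("MlpPolicy", vec_env, learning_rate=linear_schedule(7e-4))
```

With only 20 transitions per update, A2C performs many small, frequent updates, which is part of why it tends to be CPU-bound rather than GPU-bound.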
Note
A2C is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using SubprocVecEnv instead of the default DummyVecEnv:
```python
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# The __main__ guard is needed because SubprocVecEnv starts new Python
# processes, which may re-import this module
if __name__ == "__main__":
    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
    model = A2C("MlpPolicy", env, device="cpu")
    model.learn(total_timesteps=25_000)
```
For more information, see Vectorized Environments, Issue #1245 or the Multiprocessing notebook.
Results
Atari Games
The complete learning curves are available in the associated PR #110.
PyBullet Environments
Results on the PyBullet benchmark (2M steps) using 6 seeds. The complete learning curves are available in the associated issue #48.
Note
Hyperparameters from the gSDE paper were used (as they are tuned for PyBullet envs).
Gaussian means that unstructured Gaussian noise is used for exploration; gSDE means that generalized State-Dependent Exploration is used instead.
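The difference between the two exploration schemes can be sketched in a few lines. With unstructured Gaussian noise, a fresh perturbation is drawn at every step. With gSDE, exploration weights are drawn only occasionally (every sde_sample_freq steps, or only at the beginning of the rollout by default) and the perturbation is a linear function of the policy's latent features, so the same state gets the same noise until the weights are resampled. This toy version assumes a single action dimension, plain Python lists for features and a fixed standard deviation; all function names are hypothetical.

```python
import random


def gaussian_noise(std: float, rng: random.Random) -> float:
    # Unstructured exploration: independent noise at every step
    return rng.gauss(0.0, std)


def sample_noise_weights(n_features: int, std: float, rng: random.Random) -> list[float]:
    # gSDE-style: draw exploration weights theta_eps ~ N(0, std^2), kept fixed until resampled
    return [rng.gauss(0.0, std) for _ in range(n_features)]


def gsde_noise(features: list[float], weights: list[float]) -> float:
    # State-dependent noise: a linear function of the current features z(s)
    return sum(f * w for f, w in zip(features, weights))


rng = random.Random(0)
features = [0.5, -1.0, 2.0]   # stand-in for the policy's latent features z(s)
weights = sample_noise_weights(len(features), std=0.1, rng=rng)

# Same state and same weights give an identical perturbation (smooth, repeatable exploration)
print(gsde_noise(features, weights) == gsde_noise(features, weights))
# Unstructured Gaussian noise changes from call to call, even for the same state
print(gaussian_noise(0.1, rng), gaussian_noise(0.1, rng))
```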
Environments | A2C (Gaussian) | A2C (gSDE) | PPO (Gaussian) | PPO (gSDE)
---|---|---|---|---
HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45
Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76
Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21
Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64
How to replicate the results?
Clone the rl-zoo repo:
```shell
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
```
Run the benchmark (replace $ENV_ID with the envs mentioned above):
```shell
python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000
```
Plot the results (here for PyBullet envs only):
```shell
python scripts/all_plots.py -a a2c -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/a2c_results
python scripts/plot_from_file.py -i logs/a2c_results.pkl -latex -l A2C
```
Parameters
- class stable_baselines3.a2c.A2C(policy, env, learning_rate=0.0007, n_steps=5, gamma=0.99, gae_lambda=1.0, ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, rms_prop_eps=1e-05, use_rms_prop=True, use_sde=False, sde_sample_freq=-1, rollout_buffer_class=None, rollout_buffer_kwargs=None, normalize_advantage=False, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)
Advantage Actor Critic (A2C)
Paper: https://arxiv.org/abs/1602.01783 Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and Stable Baselines (https://github.com/hill-a/stable-baselines)
Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
- Parameters:
policy (ActorCriticPolicy) – The policy model to use (MlpPolicy, CnnPolicy, …)
env (Env | VecEnv | str) – The environment to learn from (if registered in Gym, can be str)
learning_rate (float | Callable[[float], float]) – The learning rate; it can be a function of the current progress remaining (from 1 to 0)
n_steps (int) – The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is the number of environment copies running in parallel)
gamma (float) – Discount factor
gae_lambda (float) – Factor for trade-off of bias vs variance for Generalized Advantage Estimator. Equivalent to classic advantage when set to 1.
ent_coef (float) – Entropy coefficient for the loss calculation
vf_coef (float) – Value function coefficient for the loss calculation
max_grad_norm (float) – The maximum value for the gradient clipping
rms_prop_eps (float) – RMSProp epsilon. It stabilizes the square-root computation in the denominator of the RMSProp update
use_rms_prop (bool) – Whether to use RMSprop (default) or Adam as optimizer
use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)
sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)
rollout_buffer_class (Type[RolloutBuffer] | None) – Rollout buffer class to use. If None, it will be automatically selected.
rollout_buffer_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the rollout buffer on creation.
normalize_advantage (bool) – Whether or not to normalize the advantage
stats_window_size (int) – Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over
tensorboard_log (str | None) – the log location for tensorboard (if None, no logging)
policy_kwargs (Dict[str, Any] | None) – additional arguments to be passed to the policy on creation
verbose (int) – Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages
seed (int | None) – Seed for the pseudo random generators
device (device | str) – Device (cpu, cuda, …) on which the code should be run. If set to auto, the code will run on the GPU if possible.
_init_setup_model (bool) – Whether or not to build the network at the creation of the instance
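How gamma, gae_lambda, normalize_advantage, ent_coef and vf_coef enter a single update can be sketched for one environment and one rollout. This is a simplified pure-Python illustration under stated assumptions, not the library's code: it computes Generalized Advantage Estimation backwards over the rollout, bootstraps from the value of the state after the last step, and combines the usual actor-critic terms (policy gradient, value regression, entropy bonus) into one scalar loss. With gae_lambda=1.0 (the default), the return reduces to the bootstrapped n-step return. All function names and input numbers are made up for illustration.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=1.0):
    """Return (advantages, returns) for one rollout from a single environment.

    dones[t] is 1.0 if the episode ended after step t, else 0.0.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        next_non_terminal = 1.0 - dones[t]
        # TD error: one-step bootstrapped target minus current value estimate
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        # Exponentially weighted sum of future TD errors, cut at episode ends
        gae = delta + gamma * gae_lambda * next_non_terminal * gae
        advantages[t] = gae
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


def mean(xs):
    return sum(xs) / len(xs)


def a2c_loss(advantages, returns, values, log_probs, entropies,
             ent_coef=0.0, vf_coef=0.5, normalize_advantage=False):
    if normalize_advantage:
        # Zero mean, unit scale (population std in this sketch)
        mu = mean(advantages)
        std = mean([(a - mu) ** 2 for a in advantages]) ** 0.5
        advantages = [(a - mu) / (std + 1e-8) for a in advantages]
    # Policy gradient term (advantages are treated as constants)
    policy_loss = -mean([a * lp for a, lp in zip(advantages, log_probs)])
    # Critic regression towards the GAE returns
    value_loss = mean([(r - v) ** 2 for r, v in zip(returns, values)])
    # Negative entropy: minimizing it encourages exploration when ent_coef > 0
    entropy_loss = -mean(entropies)
    return policy_loss + ent_coef * entropy_loss + vf_coef * value_loss


adv, ret = compute_gae([1.0, 1.0, 1.0], [0.5, 0.5, 0.5], [0.0, 0.0, 0.0], last_value=0.5)
loss = a2c_loss(adv, ret, [0.5, 0.5, 0.5],
                log_probs=[-0.7, -0.7, -0.7], entropies=[0.69, 0.69, 0.69])
print(adv, ret, loss)
```

In a real update the loss would be a differentiable tensor and the gradient would be clipped to max_grad_norm before the optimizer step; the sketch stops at the scalar value.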
- collect_rollouts(env, callback, rollout_buffer, n_rollout_steps)
Collect experiences using the current policy and fill a RolloutBuffer. The term rollout here refers to the model-free notion and should not be confused with the concept of rollout used in model-based RL or planning.
- Parameters:
env (VecEnv) – The training environment
callback (BaseCallback) – Callback that will be called at each step (and at the beginning and end of the rollout)
rollout_buffer (RolloutBuffer) – Buffer to fill with rollouts
n_rollout_steps (int) – Number of experiences to collect per environment
- Returns:
True if function returned with at least n_rollout_steps collected, False if callback terminated rollout prematurely.
- Return type:
bool
- get_env()
Returns the current environment (can be None if not defined).
- Returns:
The current environment
- Return type:
VecEnv | None
- get_parameters()
Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions).
- Returns:
Mapping from names of the objects to PyTorch state-dicts.
- Return type:
Dict[str, Dict]
- get_vec_normalize_env()
Return the VecNormalize wrapper of the training env if it exists.
- Returns:
The VecNormalize env.
- Return type:
VecNormalize | None
- learn(total_timesteps, callback=None, log_interval=100, tb_log_name='A2C', reset_num_timesteps=True, progress_bar=False)
Return a trained model.
- Parameters:
total_timesteps (int) – The total number of samples (env steps) to train on
callback (None | Callable | List[BaseCallback] | BaseCallback) – callback(s) called at every step with state of the algorithm.
log_interval (int) – for on-policy algos (e.g., PPO, A2C, …) this is the number of training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging; for off-policy algos (e.g., TD3, SAC, …) this is the number of episodes before logging.
tb_log_name (str) – the name of the run for TensorBoard logging
reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)
progress_bar (bool) – Display a progress bar using tqdm and rich.
self (SelfA2C) –
- Returns:
the trained model
- Return type:
SelfA2C
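For on-policy algorithms such as A2C, log_interval counts training iterations, so the spacing of log lines in environment steps depends on n_steps and the number of environments. A quick check with the defaults (log_interval=100, n_steps=5) and the 4 environments of the CartPole example earlier on this page; the helper name is hypothetical:

```python
def timesteps_between_logs(log_interval: int, n_steps: int, n_envs: int) -> int:
    # One training iteration collects n_steps transitions from each of the n_envs environments
    return log_interval * n_steps * n_envs


print(timesteps_between_logs(log_interval=100, n_steps=5, n_envs=4))
```

So, under these assumptions, a 25000-step run would print a log block roughly every 2000 environment steps.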
- classmethod load(path, env=None, device='auto', custom_objects=None, print_system_info=False, force_reset=True, **kwargs)
Load the model from a zip-file. Warning: load re-creates the model from scratch; it does not update it in-place! For an in-place load, use set_parameters instead.
- Parameters:
path (str | Path | BufferedIOBase) – path to the file (or a file-like) where to load the agent from
env (Env | VecEnv | None) – the new environment to run the loaded model on (can be None if you only need prediction from a trained model); it takes priority over any saved environment
device (device | str) – Device on which the code should run.
custom_objects (Dict[str, Any] | None) – Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in keras.models.load_model. Useful when you have an object in the file that cannot be deserialized.
print_system_info (bool) – Whether to print system info from the saved model and the current system info (useful to debug loading issues)
force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597
kwargs – extra arguments to change the model when loading
- Returns:
new model instance with loaded parameters
- Return type:
SelfBaseAlgorithm
- predict(observation, state=None, episode_start=None, deterministic=False)
Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images).
- Parameters:
observation (ndarray | Dict[str, ndarray]) – the input observation
state (Tuple[ndarray, ...] | None) – The last hidden states (can be None, used in recurrent policies)
episode_start (ndarray | None) – The last masks (can be None, used in recurrent policies); this corresponds to the beginning of episodes, where the hidden states of the RNN must be reset.
deterministic (bool) – Whether or not to return deterministic actions.
- Returns:
the model’s action and the next hidden state (used in recurrent policies)
- Return type:
Tuple[ndarray, Tuple[ndarray, …] | None]
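For a discrete action space, the deterministic flag roughly amounts to taking the most likely action (the mode of the policy's categorical distribution) instead of sampling from it; for Box actions with a Gaussian policy the analogue would be returning the mean. A pure-Python sketch of the discrete case; `softmax` and `select_action` are hypothetical helpers and the logits are made up:

```python
import math
import random


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits, deterministic, rng):
    probs = softmax(logits)
    if deterministic:
        # Mode of the categorical distribution
        return max(range(len(probs)), key=probs.__getitem__)
    # Stochastic: sample an action index according to its probability
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(42)
logits = [0.1, 2.0, -1.0]
print(select_action(logits, deterministic=True, rng=rng))
print([select_action(logits, deterministic=False, rng=rng) for _ in range(10)])
```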
- save(path, exclude=None, include=None)
Save all the attributes of the object and the model parameters in a zip-file.
- Parameters:
path (str | Path | BufferedIOBase) – path to the file where the rl agent should be saved
exclude (Iterable[str] | None) – names of parameters that should be excluded in addition to the default ones
include (Iterable[str] | None) – names of parameters that might be excluded but should be included anyway
- Return type:
None
- set_env(env, force_reset=True)
Checks the validity of the environment and, if it is coherent, sets it as the current environment. Any non-vectorized env is wrapped into a vectorized one. Checked parameters: observation_space and action_space.
- Parameters:
env (Env | VecEnv) – The environment for learning a policy
force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See issue https://github.com/DLR-RM/stable-baselines3/issues/597
- Return type:
None
- set_logger(logger)
Setter for the logger object.
Warning
When passing a custom logger object, this will overwrite the tensorboard_log and verbose settings passed to the constructor.
- Parameters:
logger (Logger) –
- Return type:
None
- set_parameters(load_path_or_dict, exact_match=True, device='auto')
Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see get_parameters).
- Parameters:
load_path_or_dict (str | Dict[str, Tensor]) – Location of the saved data (path or file-like, see save), or a nested dictionary containing nn.Module parameters used by the policy. The dictionary maps object names to a state-dictionary returned by torch.nn.Module.state_dict().
exact_match (bool) – If True, the given parameters should include parameters for each module and each of their parameters, otherwise an exception is raised. If set to False, this can be used to update only specific parameters.
device (device | str) – Device on which the code should run.
- Return type:
None
- set_random_seed(seed=None)
Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)
- Parameters:
seed (int | None) –
- Return type:
None
A2C Policies
- stable_baselines3.a2c.MlpPolicy
alias of ActorCriticPolicy
- class stable_baselines3.common.policies.ActorCriticPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)
Policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the like.
- Parameters:
observation_space (Space) – Observation space
action_space (Space) – Action space
lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
net_arch (List[int] | Dict[str, List[int]] | None) – The specification of the policy and value networks.
activation_fn (Type[Module]) – Activation function
ortho_init (bool) – Whether or not to use orthogonal initialization
use_sde (bool) – Whether to use State Dependent Exploration or not
log_std_init (float) – Initial value for the log standard deviation
full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE
use_expln (bool) – Use the expln() function instead of exp() to ensure a positive standard deviation (cf. paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
squash_output (bool) – Whether to squash the output using a tanh function; this ensures boundaries when using gSDE.
features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
features_extractor_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the features extractor.
share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.
normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
optimizer_kwargs (Dict[str, Any] | None) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
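The expln() alternative mentioned under use_expln can be written down directly. A minimal sketch, assuming the definition from the gSDE paper: it equals exp(x) for x <= 0 and log(x + 1) + 1 for x > 0, so it is continuous at 0, always positive, and grows only logarithmically for large inputs. The function below is an illustration, not the library's implementation.

```python
import math


def expln(x: float) -> float:
    """Positive, continuous alternative to exp() for turning a log-std into a std."""
    if x <= 0:
        return math.exp(x)       # identical to exp() for non-positive inputs
    return math.log1p(x) + 1.0   # grows logarithmically instead of exponentially


for x in (-2.0, 0.0, 1.0, 5.0):
    print(f"x={x:+.1f}  exp={math.exp(x):8.3f}  expln={expln(x):6.3f}")
```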
- evaluate_actions(obs, actions)
Evaluate actions according to the current policy, given the observations.
- Parameters:
obs (Tensor | Dict[str, Tensor]) – Observation
actions (Tensor) – Actions
- Returns:
estimated value, log likelihood of taking those actions and entropy of the action distribution.
- Return type:
Tuple[Tensor, Tensor, Tensor | None]
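For a discrete action space, the log likelihood and entropy returned by evaluate_actions are those of a categorical distribution over the action logits (the estimated value comes separately from the critic). A small pure-Python sketch of those two quantities for a single observation; `log_softmax` and `categorical_stats` are hypothetical helpers and the logits are made up:

```python
import math


def log_softmax(logits):
    # Log-sum-exp with the max subtracted for numerical stability
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


def categorical_stats(logits, action):
    log_probs = log_softmax(logits)
    log_prob = log_probs[action]                           # log pi(a | s)
    entropy = -sum(math.exp(lp) * lp for lp in log_probs)  # H = -sum p log p
    return log_prob, entropy


log_prob, entropy = categorical_stats([0.0, 0.0], action=1)
print(log_prob, entropy)
```

With two equally likely actions, the log-probability is -log 2 and the entropy is log 2, the maximum for two actions; more peaked logits lower the entropy.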
- extract_features(obs, features_extractor=None)
Preprocess the observation if needed and extract features.
- Parameters:
obs (Tensor | Dict[str, Tensor]) – Observation
features_extractor (BaseFeaturesExtractor | None) – The features extractor to use. If None, then self.features_extractor is used.
- Returns:
The extracted features. If features extractor is not shared, returns a tuple with the features for the actor and the features for the critic.
- Return type:
Tensor | Tuple[Tensor, Tensor]
- forward(obs, deterministic=False)
Forward pass in all the networks (actor and critic)
- Parameters:
obs (Tensor) – Observation
deterministic (bool) – Whether to sample or use deterministic actions
- Returns:
action, value and log probability of the action
- Return type:
Tuple[Tensor, Tensor, Tensor]
- get_distribution(obs)
Get the current policy distribution given the observations.
- Parameters:
obs (Tensor | Dict[str, Tensor]) –
- Returns:
the action distribution.
- Return type:
- predict_values(obs)
Get the estimated values according to the current policy given the observations.
- Parameters:
obs (Tensor | Dict[str, Tensor]) – Observation
- Returns:
the estimated values.
- Return type:
Tensor
- reset_noise(n_envs=1)
Sample new weights for the exploration matrix.
- Parameters:
n_envs (int) –
- Return type:
None
- stable_baselines3.a2c.CnnPolicy
alias of ActorCriticCnnPolicy
- class stable_baselines3.common.policies.ActorCriticCnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)
CNN policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the like.
- Parameters:
observation_space (Space) – Observation space
action_space (Space) – Action space
lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
net_arch (List[int] | Dict[str, List[int]] | None) – The specification of the policy and value networks.
activation_fn (Type[Module]) – Activation function
ortho_init (bool) – Whether or not to use orthogonal initialization
use_sde (bool) – Whether to use State Dependent Exploration or not
log_std_init (float) – Initial value for the log standard deviation
full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE
use_expln (bool) – Use the expln() function instead of exp() to ensure a positive standard deviation (cf. paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
squash_output (bool) – Whether to squash the output using a tanh function; this ensures boundaries when using gSDE.
features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
features_extractor_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the features extractor.
share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.
normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
optimizer_kwargs (Dict[str, Any] | None) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
- stable_baselines3.a2c.MultiInputPolicy
alias of MultiInputActorCriticPolicy
- class stable_baselines3.common.policies.MultiInputActorCriticPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.Tanh'>, ortho_init=True, use_sde=False, log_std_init=0.0, full_std=True, use_expln=False, squash_output=False, features_extractor_class=<class 'stable_baselines3.common.torch_layers.CombinedExtractor'>, features_extractor_kwargs=None, share_features_extractor=True, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None)
Multi-input policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the like.
- Parameters:
observation_space (Dict) – Observation space (a Dict space)
action_space (Space) – Action space
lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
net_arch (List[int] | Dict[str, List[int]] | None) – The specification of the policy and value networks.
activation_fn (Type[Module]) – Activation function
ortho_init (bool) – Whether or not to use orthogonal initialization
use_sde (bool) – Whether to use State Dependent Exploration or not
log_std_init (float) – Initial value for the log standard deviation
full_std (bool) – Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE
use_expln (bool) – Use the expln() function instead of exp() to ensure a positive standard deviation (cf. paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
squash_output (bool) – Whether to squash the output using a tanh function; this ensures boundaries when using gSDE.
features_extractor_class (Type[BaseFeaturesExtractor]) – Uses the CombinedExtractor
features_extractor_kwargs (Dict[str, Any] | None) – Keyword arguments to pass to the features extractor.
share_features_extractor (bool) – If True, the features extractor is shared between the policy and value networks.
normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
optimizer_kwargs (Dict[str, Any] | None) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer