A2C
A synchronous, deterministic variant of Asynchronous Advantage Actor Critic (A3C). It uses multiple workers to avoid the use of a replay buffer.
Warning
If you find training unstable or want to match the performance of stable-baselines A2C, consider using the RMSpropTFLike optimizer from stable_baselines3.common.sb2_compat.rmsprop_tf_like. You can change the optimizer with A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike)). Read more here.
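To see why the optimizer choice can matter, here is a simplified, pure-Python comparison of a single RMSprop step (no momentum, scalar parameter) in the two styles. It assumes the differences described in the `RMSpropTFLike` docstring: TensorFlow-style RMSprop adds epsilon inside the square root and initializes the squared-gradient average to ones, while PyTorch's `torch.optim.RMSprop` adds epsilon outside the square root and starts from zeros. The function names and the numbers are hypothetical.

```python
import math


def torch_style_step(param, grad, square_avg, lr, alpha=0.99, eps=1e-5):
    # PyTorch-style RMSprop (no momentum): eps is added outside the sqrt
    square_avg = alpha * square_avg + (1 - alpha) * grad ** 2
    param -= lr * grad / (math.sqrt(square_avg) + eps)
    return param, square_avg


def tf_style_step(param, grad, square_avg, lr, alpha=0.99, eps=1e-5):
    # TF1-style RMSprop (no momentum): eps is added inside the sqrt
    square_avg = alpha * square_avg + (1 - alpha) * grad ** 2
    param -= lr * grad / math.sqrt(square_avg + eps)
    return param, square_avg


lr, grad = 7e-4, 0.1
# Assumed initializations: zeros for PyTorch, ones for TF1
p_torch, _ = torch_style_step(0.0, grad, square_avg=0.0, lr=lr)
p_tf, _ = tf_style_step(0.0, grad, square_avg=1.0, lr=lr)
print(f"torch-style first step: {p_torch:.6f}")
print(f"tf-style first step:    {p_tf:.6f}")
```

With these numbers the first PyTorch-style step is roughly 100 times larger, because the zero-initialized average makes the denominator tiny early in training. This illustrates only one of the differences; the real optimizers also handle momentum and other details.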
Notes
Original paper: https://arxiv.org/abs/1602.01783
OpenAI blog post: https://openai.com/blog/baselines-acktr-a2c/
Can I use?
Recurrent policies: ✔️
Multi processing: ✔️
Gym spaces:
Space | Action | Observation
---|---|---
Discrete | ✔️ | ✔️
Box | ✔️ | ✔️
MultiDiscrete | ✔️ | ✔️
MultiBinary | ✔️ | ✔️
Example
Train an A2C agent on CartPole-v1 using 4 environments.
import gym
from stable_baselines3 import A2C
from stable_baselines3.a2c import MlpPolicy
from stable_baselines3.common.cmd_util import make_vec_env
# Parallel environments
env = make_vec_env('CartPole-v1', n_envs=4)
model = A2C(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("a2c_cartpole")
del model # remove to demonstrate saving and loading
model = A2C.load("a2c_cartpole")
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
Parameters
class stable_baselines3.a2c.A2C(policy: Union[str, Type[stable_baselines3.common.policies.ActorCriticPolicy]], env: Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv, str], learning_rate: Union[float, Callable] = 0.0007, n_steps: int = 5, gamma: float = 0.99, gae_lambda: float = 1.0, ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, rms_prop_eps: float = 1e-05, use_rms_prop: bool = True, use_sde: bool = False, sde_sample_freq: int = -1, normalize_advantage: bool = False, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, policy_kwargs: Optional[Dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[torch.device, str] = 'auto', _init_setup_model: bool = True)
Advantage Actor Critic (A2C)
Paper: https://arxiv.org/abs/1602.01783
Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and Stable Baselines (https://github.com/hill-a/stable-baselines)
Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
- Parameters
policy – (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, …)
env – (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
learning_rate – (float or callable) The learning rate; it can also be a function (schedule)
n_steps – (int) The number of steps to run for each environment per update (i.e. the batch size is n_steps * n_env, where n_env is the number of environment copies running in parallel)
gamma – (float) Discount factor
gae_lambda – (float) Factor for the trade-off between bias and variance in the Generalized Advantage Estimator. Equivalent to the classic advantage when set to 1.
ent_coef – (float) Entropy coefficient for the loss calculation
vf_coef – (float) Value function coefficient for the loss calculation
max_grad_norm – (float) The maximum value for the gradient clipping
rms_prop_eps – (float) RMSProp epsilon. It stabilizes the square root computation in the denominator of the RMSProp update
use_rms_prop – (bool) Whether to use RMSprop (default) or Adam as optimizer
use_sde – (bool) Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)
sde_sample_freq – (int) Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)
normalize_advantage – (bool) Whether or not to normalize the advantage
tensorboard_log – (str) The log location for TensorBoard (if None, no logging)
create_eval_env – (bool) Whether to create a second environment that will be used for evaluating the agent periodically (only available when passing a string for the environment)
policy_kwargs – (dict) additional arguments to be passed to the policy on creation
verbose – (int) the verbosity level: 0 no output, 1 info, 2 debug
seed – (int) Seed for the pseudo random generators
device – (str or th.device) Device (cpu, cuda, …) on which the code should be run. If set to auto, the code will run on the GPU if possible.
_init_setup_model – (bool) Whether or not to build the network at the creation of the instance
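Because learning_rate may be a callable, a schedule can be passed instead of a constant. The sketch below assumes, as in current SB3 releases, that the callable receives the remaining training progress, going from 1 at the start of training to 0 at the end; `linear_schedule` is a hypothetical helper, not part of the library.

```python
from typing import Callable


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Hypothetical helper: linearly decay the learning rate to zero."""

    def schedule(progress_remaining: float) -> float:
        # Assumed: progress_remaining goes from 1.0 (start) to 0.0 (end)
        return progress_remaining * initial_value

    return schedule


lr_schedule = linear_schedule(7e-4)
print(lr_schedule(1.0))  # learning rate at the start of training
print(lr_schedule(0.5))  # learning rate halfway through training
```

Under that assumption it would be passed as `A2C("MlpPolicy", env, learning_rate=linear_schedule(7e-4))`.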
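The interplay of n_steps, gamma and gae_lambda can be illustrated with a plain-Python sketch of Generalized Advantage Estimation for a single environment. This is a simplified illustration, not SB3's RolloutBuffer code: `compute_gae` is a hypothetical name, and `dones[t]` is assumed to mark that the episode ended after step `t`. With `gae_lambda=1.0` the advantage reduces to the classic n-step bootstrapped return minus the value estimate.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    last_value: float,
    gamma: float = 0.99,
    gae_lambda: float = 1.0,
) -> Tuple[List[float], List[float]]:
    """Return (advantages, returns) for one environment's rollout."""
    n_steps = len(rewards)
    advantages = [0.0] * n_steps
    last_adv = 0.0
    for t in reversed(range(n_steps)):
        # Bootstrap from the next state's value (last_value after the rollout)
        next_value = last_value if t == n_steps - 1 else values[t + 1]
        # Do not bootstrap or propagate advantages across an episode end
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted (by gamma * gae_lambda) sum of TD errors
        last_adv = delta + gamma * gae_lambda * not_done * last_adv
        advantages[t] = last_adv
    # Targets for the value function
    returns = [adv + v for adv, v in zip(advantages, values)]
    return advantages, returns


advantages, returns = compute_gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.5, 0.5, 0.5],
    dones=[False, False, False],
    last_value=2.0,
    gamma=0.9,
)
print(returns)  # n-step bootstrapped returns, since gae_lambda=1.0
```

Here the first return equals 1 + 0.9 + 0.81 + 0.729 * 2.0; lowering gae_lambda would weight later TD errors less, trading variance for bias.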
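ent_coef, vf_coef and normalize_advantage all enter the loss minimized at each update. As a rough sketch, assuming the usual A2C formulation (policy-gradient loss, plus vf_coef times the value MSE, plus ent_coef times the negative entropy), the total loss could be computed as follows. `a2c_loss` is a hypothetical name, and this works on plain Python lists; a real implementation uses tensors so that the loss can be backpropagated.

```python
import statistics
from typing import List


def a2c_loss(
    log_probs: List[float],
    advantages: List[float],
    values: List[float],
    returns: List[float],
    entropies: List[float],
    ent_coef: float = 0.0,
    vf_coef: float = 0.5,
    normalize_advantage: bool = False,
) -> float:
    if normalize_advantage:
        # Zero mean, unit (sample) std; needs at least two samples
        mean = statistics.mean(advantages)
        std = statistics.stdev(advantages)
        advantages = [(a - mean) / (std + 1e-8) for a in advantages]
    # Policy gradient loss: favor actions with positive advantage
    policy_loss = -statistics.mean(a * lp for a, lp in zip(advantages, log_probs))
    # Value loss: mean squared error between returns and value estimates
    value_loss = statistics.mean((r - v) ** 2 for r, v in zip(returns, values))
    # Negative entropy: minimizing it encourages exploration
    entropy_loss = -statistics.mean(entropies)
    return policy_loss + ent_coef * entropy_loss + vf_coef * value_loss


loss = a2c_loss(
    log_probs=[-0.5, -1.0],
    advantages=[1.0, -1.0],
    values=[0.0, 0.0],
    returns=[1.0, 1.0],
    entropies=[0.7, 0.7],
    ent_coef=0.01,
)
print(loss)
```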
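max_grad_norm bounds the global L2 norm of all gradients before each optimizer step. In a PyTorch code base this is typically done with `torch.nn.utils.clip_grad_norm_`; the pure-Python sketch below only mirrors its core rule (rescale every gradient by max_norm / total_norm when the norm is too large) on a flat list of hypothetical gradient values.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    # Global L2 norm over all gradient entries
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Small constant avoids division by zero
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        # Uniform rescaling keeps the direction and shrinks only the length
        return [g * clip_coef for g in grads]
    return list(grads)


clipped = clip_by_global_norm([3.0, 4.0], max_norm=0.5)
print(clipped)  # approximately [0.3, 0.4]
```

Gradients whose norm is already below max_grad_norm are returned unchanged.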
collect_rollouts(env: stable_baselines3.common.vec_env.base_vec_env.VecEnv, callback: stable_baselines3.common.callbacks.BaseCallback, rollout_buffer: stable_baselines3.common.buffers.RolloutBuffer, n_rollout_steps: int) → bool
Collect rollouts using the current policy and fill a RolloutBuffer.
- Parameters
env – (VecEnv) The training environment
callback – (BaseCallback) Callback that will be called at each step (and at the beginning and end of the rollout)
rollout_buffer – (RolloutBuffer) Buffer to fill with rollouts
n_rollout_steps – (int) Number of experiences to collect per environment
- Returns
(bool) True if function returned with at least n_rollout_steps collected, False if callback terminated rollout prematurely.
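The return value reflects the callback protocol: collection stops early as soon as the callback's per-step hook returns False. The sketch below is a simplified, hypothetical version of that control flow, with no policy and no buffer bookkeeping; `step_env` and `on_step` are stand-ins for the environment step and `BaseCallback.on_step()`.

```python
from typing import Callable, List, Tuple


def collect_rollouts_sketch(
    step_env: Callable[[], float],
    on_step: Callable[[], bool],
    n_rollout_steps: int,
) -> Tuple[bool, List[float]]:
    buffer: List[float] = []
    n_steps = 0
    while n_steps < n_rollout_steps:
        # One step in the (vectorized) env; here just a scalar reward
        buffer.append(step_env())
        n_steps += 1
        # A callback returning False aborts the rollout
        if on_step() is False:
            return False, buffer
    return True, buffer


ok, rollout = collect_rollouts_sketch(lambda: 1.0, lambda: True, n_rollout_steps=5)
print(ok, len(rollout))  # True 5

answers = iter([True, True, False])
stopped, partial = collect_rollouts_sketch(lambda: 1.0, lambda: next(answers), n_rollout_steps=5)
print(stopped, len(partial))  # False 3
```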
excluded_save_params() → List[str]
Returns the names of the parameters that should be excluded by default when saving the model.
- Returns
([str]) List of parameters that should be excluded from save
get_env() → Optional[stable_baselines3.common.vec_env.base_vec_env.VecEnv]
Returns the current environment (can be None if not defined).
- Returns
(Optional[VecEnv]) The current environment
get_torch_variables() → Tuple[List[str], List[str]]
See the base class.
get_vec_normalize_env() → Optional[stable_baselines3.common.vec_env.vec_normalize.VecNormalize]
Return the VecNormalize wrapper of the training env if it exists.
- Returns
(Optional[VecNormalize]) The VecNormalize env.
learn(total_timesteps: int, callback: Union[None, Callable, List[stable_baselines3.common.callbacks.BaseCallback], stable_baselines3.common.callbacks.BaseCallback] = None, log_interval: int = 100, eval_env: Optional[Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv]] = None, eval_freq: int = -1, n_eval_episodes: int = 5, tb_log_name: str = 'A2C', eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True) → stable_baselines3.a2c.a2c.A2C
Return a trained model.
- Parameters
total_timesteps – (int) The total number of samples (env steps) to train on
callback – (MaybeCallback) callback(s) called at every step with state of the algorithm.
log_interval – (int) The number of training iterations (one rollout collection followed by one update) between two logs.
tb_log_name – (str) the name of the run for TensorBoard logging
eval_env – (gym.Env) Environment that will be used to evaluate the agent
eval_freq – (int) Evaluate the agent every eval_freq timesteps (this may vary a little)
n_eval_episodes – (int) Number of episodes to evaluate the agent
eval_log_path – (Optional[str]) Path to a folder where the evaluations will be saved
reset_num_timesteps – (bool) whether or not to reset the current timestep number (used in logging)
- Returns
(BaseAlgorithm) the trained model
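Since each training iteration collects n_steps transitions in each of the parallel environments before performing one update, the number of updates run by learn can be estimated from total_timesteps. The helper below is a hypothetical illustration that assumes training stops at the first iteration reaching total_timesteps, so the count is rounded up.

```python
import math


def num_updates(total_timesteps: int, n_steps: int, n_envs: int) -> int:
    # Transitions gathered per iteration, i.e. the batch size of one update
    batch_size = n_steps * n_envs
    # Training continues until at least total_timesteps have been collected
    return math.ceil(total_timesteps / batch_size)


# Settings of the CartPole example: default n_steps=5 and 4 environments
print(num_updates(25_000, n_steps=5, n_envs=4))  # 1250
```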
classmethod load(load_path: str, env: Optional[Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv]] = None, **kwargs) → stable_baselines3.common.base_class.BaseAlgorithm
Load the model from a zip-file.
- Parameters
load_path – the location of the saved data
env – the new environment to run the loaded model on (can be None if you only need prediction from a trained model); it has priority over any saved environment
kwargs – extra arguments to change the model when loading
predict(observation: numpy.ndarray, state: Optional[numpy.ndarray] = None, mask: Optional[numpy.ndarray] = None, deterministic: bool = False) → Tuple[numpy.ndarray, Optional[numpy.ndarray]]
Get the model’s action(s) from an observation.
- Parameters
observation – (np.ndarray) the input observation
state – (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)
mask – (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies)
deterministic – (bool) Whether or not to return deterministic actions.
- Returns
(Tuple[np.ndarray, Optional[np.ndarray]]) the model’s action and the next state (used in recurrent policies)
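For a discrete action space, deterministic=True usually means returning the most likely action (the mode of the policy's distribution) rather than sampling one. The pure-Python sketch below illustrates the difference on a hypothetical categorical policy output `probs`; it does not call the library.

```python
import random
from typing import List


def select_action(probs: List[float], deterministic: bool, rng: random.Random) -> int:
    if deterministic:
        # Mode of the categorical distribution: the most likely action
        return max(range(len(probs)), key=lambda a: probs[a])
    # Otherwise sample an action index in proportion to its probability
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
probs = [0.2, 0.7, 0.1]
print(select_action(probs, deterministic=True, rng=rng))  # 1
samples = [select_action(probs, deterministic=False, rng=rng) for _ in range(1000)]
print(samples.count(1) / len(samples))  # close to 0.7
```

Deterministic actions are often preferred for evaluation, while sampling keeps the exploration behavior used during training.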
save(path: Union[str, pathlib.Path, io.BufferedIOBase], exclude: Optional[Iterable[str]] = None, include: Optional[Iterable[str]] = None) → None
Save all the attributes of the object and the model parameters in a zip-file.
- Parameters
path – (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file where the RL agent should be saved
exclude – names of parameters that should be excluded in addition to the default ones
include – names of parameters that might be excluded but should be included anyway
set_env(env: Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv]) → None
Checks the validity of the environment and, if it is coherent, sets it as the current environment. Any non-vectorized env is also wrapped into a vectorized one. Checked parameters: observation_space and action_space.
- Parameters
env – The environment for learning a policy
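The validity check can be pictured as an equality test on the two spaces. The sketch below is hypothetical (the library performs this check internally, and its exact logic and error messages may differ); `check_spaces` is an illustrative name, and tuples stand in for gym spaces so that it runs without gym installed.

```python
from types import SimpleNamespace


def check_spaces(env, observation_space, action_space) -> None:
    # Hypothetical check: the new env must expose the same spaces as the model
    if env.observation_space != observation_space:
        raise ValueError(
            f"Observation spaces do not match: {observation_space} != {env.observation_space}"
        )
    if env.action_space != action_space:
        raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")


# Stand-ins for gym spaces: (type, shape or size) tuples
model_obs_space, model_act_space = ("Box", (4,)), ("Discrete", 2)
good_env = SimpleNamespace(observation_space=("Box", (4,)), action_space=("Discrete", 2))
bad_env = SimpleNamespace(observation_space=("Box", (8,)), action_space=("Discrete", 2))

check_spaces(good_env, model_obs_space, model_act_space)  # passes silently

mismatch_detected = False
try:
    check_spaces(bad_env, model_obs_space, model_act_space)
except ValueError as err:
    mismatch_detected = True
    print(err)
```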
set_random_seed(seed: Optional[int] = None) → None
Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space).
- Parameters
seed – (int)