Abstract base classes for RL algorithms.

Base RL Class

Common interface for all RL algorithms.

class stable_baselines3.common.base_class.BaseAlgorithm(policy, env, learning_rate, policy_kwargs=None, stats_window_size=100, tensorboard_log=None, verbose=0, device='auto', support_multi_env=False, monitor_wrapper=True, seed=None, use_sde=False, sde_sample_freq=-1, supported_action_spaces=None)

The base class of RL algorithms.

Parameters:
  • policy (Union[str, Type[BasePolicy]]) – The policy model to use (MlpPolicy, CnnPolicy, …)

  • env (Union[Env, VecEnv, str, None]) – The environment to learn from (if registered in Gym, can be a str; can be None when loading trained models)

  • learning_rate (Union[float, Callable[[float], float]]) – learning rate for the optimizer; it can be a function of the current progress remaining (from 1 to 0)

  • policy_kwargs (Optional[Dict[str, Any]]) – Additional arguments to be passed to the policy on creation

  • stats_window_size (int) – Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over

  • tensorboard_log (Optional[str]) – the log location for tensorboard (if None, no logging)

  • verbose (int) – Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages

  • device (Union[device, str]) – Device on which the code should run. By default, it will try to use a CUDA-compatible device and fall back to CPU if that is not possible.

  • support_multi_env (bool) – Whether the algorithm supports training with multiple environments (as in A2C)

  • monitor_wrapper (bool) – When creating an environment, whether to wrap it or not in a Monitor wrapper.

  • seed (Optional[int]) – Seed for the pseudo-random generators

  • use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)

  • sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)

  • supported_action_spaces (Optional[Tuple[Space, ...]]) – The action spaces supported by the algorithm.
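Because learning_rate may be a callable that receives the progress remaining (1 at the start of training, 0 at the end), a schedule is simply a function of one float. The sketch below shows a linear decay; `linear_schedule` is a hypothetical helper written for illustration, not part of the API documented here.

```python
from typing import Callable


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Return a schedule that decays linearly from initial_value to 0.

    Hypothetical helper: progress_remaining goes from 1 (start of training)
    to 0 (end of training), as described for learning_rate.
    """

    def schedule(progress_remaining: float) -> float:
        # Scale the initial learning rate by the fraction of training left.
        return progress_remaining * initial_value

    return schedule


lr = linear_schedule(3e-4)
print(lr(1.0))  # learning rate at the start of training
print(lr(0.5))  # halfway through training
print(lr(0.0))  # end of training
```

Such a callable would then be passed in place of a float, e.g. `learning_rate=linear_schedule(3e-4)` when constructing a concrete algorithm.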

get_env()

Returns the current environment (can be None if not defined).

Return type:

Optional[VecEnv]

Returns:

The current environment

get_parameters()

Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions).

Return type:

Dict[str, Dict]

Returns:

Mapping from names of the objects to PyTorch state-dicts.

get_vec_normalize_env()

Return the VecNormalize wrapper of the training env if it exists.

Return type:

Optional[VecNormalize]

Returns:

The VecNormalize env.

abstract learn(total_timesteps, callback=None, log_interval=100, tb_log_name='run', reset_num_timesteps=True, progress_bar=False)

Return a trained model.

Parameters:
  • total_timesteps (int) – The total number of samples (env steps) to train on

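  • callback (Union[None, Callable, List[BaseCallback], BaseCallback]) – callback(s) called at every step with the state of the algorithm.

  • log_interval (int) – For on-policy algorithms (e.g. A2C, PPO), the number of training iterations (i.e. log_interval * n_steps * n_env timesteps) before logging; for off-policy algorithms (e.g. SAC, TD3), the number of episodes before logging.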

  • tb_log_name (str) – the name of the run for TensorBoard logging

  • reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)

  • progress_bar (bool) – Display a progress bar using tqdm and rich.

Return type:

TypeVar(SelfBaseAlgorithm, bound=BaseAlgorithm)

Returns:

the trained model

classmethod load(path, env=None, device='auto', custom_objects=None, print_system_info=False, force_reset=True, **kwargs)

Load the model from a zip-file. Warning: load re-creates the model from scratch; it does not update it in place! For an in-place load, use set_parameters instead.

Parameters:
  • path (Union[str, Path, BufferedIOBase]) – path to the file (or a file-like) where to load the agent from

  • env (Union[Env, VecEnv, None]) – the new environment to run the loaded model on (can be None if you only need predictions from a trained model). It has priority over any saved environment.

  • device (Union[device, str]) – Device on which the code should run.

  • custom_objects (Optional[Dict[str, Any]]) – Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in keras.models.load_model. Useful when you have an object in the file that cannot be deserialized.

  • print_system_info (bool) – Whether to print system info from the saved model and the current system info (useful to debug loading issues)

  • force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597

  • kwargs – extra arguments to change the model when loading

Return type:

TypeVar(SelfBaseAlgorithm, bound=BaseAlgorithm)

Returns:

new model instance with loaded parameters

property logger: Logger

Getter for the logger object.

predict(observation, state=None, episode_start=None, deterministic=False)

Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images).

Parameters:
  • observation (Union[ndarray, Dict[str, ndarray]]) – the input observation

  • state (Optional[Tuple[ndarray, ...]]) – The last hidden states (can be None, used in recurrent policies)

  • episode_start (Optional[ndarray]) – The last masks (can be None, used in recurrent policies). These correspond to the beginning of episodes, where the hidden states of the RNN must be reset.

  • deterministic (bool) – Whether or not to return deterministic actions.

Return type:

Tuple[ndarray, Optional[Tuple[ndarray, ...]]]

Returns:

the model’s action and the next hidden state (used in recurrent policies)
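For recurrent policies, episode_start marks which environments have just started a new episode, so their hidden state must be reset before it is used. A minimal sketch of that masking step, assuming one hidden-state vector per environment stored as a plain Python list (real recurrent policies operate on tensors; `reset_hidden_states` is a hypothetical name):

```python
from typing import List, Sequence


def reset_hidden_states(
    hidden: List[List[float]], episode_start: Sequence[bool]
) -> List[List[float]]:
    """Zero the hidden state of every env whose episode just started."""
    return [
        # Keep the previous state unless a new episode begins in this env.
        [0.0] * len(h) if start else list(h)
        for h, start in zip(hidden, episode_start)
    ]


hidden = [[0.5, -1.0], [2.0, 3.0]]
episode_start = [False, True]  # env 1 was reset after the last step
print(reset_hidden_states(hidden, episode_start))
```

Multiplying each hidden state by (1 - episode_start) is an equivalent vectorized formulation.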

save(path, exclude=None, include=None)

Save all the attributes of the object and the model parameters in a zip-file.

Parameters:
  • path (Union[str, Path, BufferedIOBase]) – path to the file where the RL agent should be saved

  • exclude (Optional[Iterable[str]]) – names of parameters that should be excluded in addition to the default ones

  • include (Optional[Iterable[str]]) – names of parameters that would otherwise be excluded but should be included anyway

Return type:

None
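The exclude and include arguments adjust a default set of attributes that are not saved. The set arithmetic can be sketched as follows; the default names below are purely hypothetical placeholders, not the exact attributes the library excludes.

```python
from typing import Iterable, Optional, Set

# Hypothetical defaults: attributes a model might skip when saving.
DEFAULT_EXCLUDED = {"env", "replay_buffer", "logger"}


def attributes_to_exclude(
    exclude: Optional[Iterable[str]] = None,
    include: Optional[Iterable[str]] = None,
) -> Set[str]:
    """Combine default exclusions with user-supplied exclude/include lists."""
    excluded = set(DEFAULT_EXCLUDED)
    if exclude is not None:
        # Extra names to drop in addition to the defaults.
        excluded |= set(exclude)
    if include is not None:
        # Names that would otherwise be dropped but should be kept anyway.
        excluded -= set(include)
    return excluded


result = attributes_to_exclude(exclude=["policy_kwargs"], include=["replay_buffer"])
print(sorted(result))
```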

set_env(env, force_reset=True)

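Checks the validity of the environment and, if it is coherent, sets it as the current environment. Any non-vectorized environment is wrapped into a vectorized one. Checked parameters: observation_space and action_space.

Parameters:
  • env (Union[Env, VecEnv]) – The environment for learning a policy

  • force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597

Return type: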

None

set_logger(logger)

Setter for the logger object.

Return type:

None

Warning

When passing a custom logger object, this will overwrite tensorboard_log and verbose settings passed to the constructor.

set_parameters(load_path_or_dict, exact_match=True, device='auto')

Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see get_parameters).

Parameters:
  • load_path_or_dict – Location of the saved data (path or file-like, see save), or a nested dictionary containing nn.Module parameters used by the policy. The dictionary maps object names to a state-dictionary returned by torch.nn.Module.state_dict().

  • exact_match (bool) – If True, the given parameters should include parameters for each module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters.

  • device (Union[device, str]) – Device on which the code should run.

Return type:

None
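With exact_match=True, the nested dictionary must cover every object and every one of its parameters; with exact_match=False, only the given entries are updated. The sketch below illustrates those semantics on plain nested dicts, assuming `current` stands in for the output of get_parameters; the helper name and the use of ValueError are assumptions, not the library's implementation.

```python
from typing import Any, Dict

# Maps object names (e.g. "policy") to that object's parameter dict.
Params = Dict[str, Dict[str, Any]]


def update_parameters(current: Params, new: Params, exact_match: bool = True) -> Params:
    """Return a copy of `current` with the entries of `new` applied.

    Hypothetical helper mimicking the exact_match semantics on plain dicts.
    """
    for name in new:
        if name not in current:
            raise ValueError(f"Key {name!r} is an invalid object name")
    if exact_match:
        # Every object and every one of its parameters must be provided.
        for name, state in current.items():
            if name not in new or set(new[name]) != set(state):
                raise ValueError(f"Parameters for {name!r} do not match exactly")
    updated = {name: dict(state) for name, state in current.items()}
    for name, state in new.items():
        updated[name].update(state)
    return updated


current = {"policy": {"w": 1.0, "b": 0.0}, "critic": {"w": 0.5}}
partial = {"policy": {"w": 2.0}}
# Partial update: only policy.w changes; exact_match=True would raise here.
print(update_parameters(current, partial, exact_match=False))
```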

set_random_seed(seed=None)

Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)

Parameters:

seed (Optional[int]) – Seed for the pseudo-random generators. If None, no seeding is done.

Return type:

None

Base Off-Policy Class

The base RL class for off-policy algorithms (e.g. SAC, TD3).

class stable_baselines3.common.off_policy_algorithm.OffPolicyAlgorithm(policy, env, learning_rate, buffer_size=1000000, learning_starts=100, batch_size=256, tau=0.005, gamma=0.99, train_freq=(1, 'step'), gradient_steps=1, action_noise=None, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, policy_kwargs=None, stats_window_size=100, tensorboard_log=None, verbose=0, device='auto', support_multi_env=False, monitor_wrapper=True, seed=None, use_sde=False, sde_sample_freq=-1, use_sde_at_warmup=False, sde_support=True, supported_action_spaces=None)

The base for off-policy algorithms (e.g. SAC, TD3).

Parameters:
  • policy (Union[str, Type[BasePolicy]]) – The policy model to use (MlpPolicy, CnnPolicy, …)

  • env (Union[Env, VecEnv, str]) – The environment to learn from (if registered in Gym, can be a str; can be None when loading trained models)

  • learning_rate (Union[float, Callable[[float], float]]) – learning rate for the optimizer; it can be a function of the current progress remaining (from 1 to 0)

  • buffer_size (int) – size of the replay buffer

  • learning_starts (int) – how many steps of the model to collect transitions for before learning starts

  • batch_size (int) – Minibatch size for each gradient update

  • tau (float) – the soft update coefficient (“Polyak update”, between 0 and 1)

  • gamma (float) – the discount factor

  • train_freq (Union[int, Tuple[int, str]]) – Update the model every train_freq steps. Alternatively pass a tuple of frequency and unit like (5, "step") or (2, "episode").

  • gradient_steps (int) – How many gradient steps to do after each rollout (see train_freq). Set to -1 to do as many gradient steps as steps done in the environment during the rollout.

  • action_noise (Optional[ActionNoise]) – the action noise type (None by default); this can help with hard exploration problems. See common.noise for the different action noise types.

  • replay_buffer_class (Optional[Type[ReplayBuffer]]) – Replay buffer class to use (for instance HerReplayBuffer). If None, it will be automatically selected.

  • replay_buffer_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the replay buffer on creation.

  • optimize_memory_usage (bool) – Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195

  • policy_kwargs (Optional[Dict[str, Any]]) – Additional arguments to be passed to the policy on creation

  • stats_window_size (int) – Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over

  • tensorboard_log (Optional[str]) – the log location for tensorboard (if None, no logging)

  • verbose (int) – Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages

  • device (Union[device, str]) – Device on which the code should run. By default, it will try to use a CUDA-compatible device and fall back to CPU if that is not possible.

  • support_multi_env (bool) – Whether the algorithm supports training with multiple environments (as in A2C)

  • monitor_wrapper (bool) – When creating an environment, whether to wrap it or not in a Monitor wrapper.

  • seed (Optional[int]) – Seed for the pseudo-random generators

  • use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)

  • sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)

  • use_sde_at_warmup (bool) – Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts)

  • sde_support (bool) – Whether the model supports gSDE or not

  • supported_action_spaces (Optional[Tuple[Space, ...]]) – The action spaces supported by the algorithm.
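train_freq accepts either an int or a (frequency, unit) tuple, while collect_rollouts expects a TrainFreq with a step or episode unit. The sketch below shows one way to normalize the former into the latter; TrainFreq and TrainFrequencyUnit are redefined locally as stand-ins, so their exact definitions here are assumptions. Following the parameter description, a bare int is read as a number of steps.

```python
from enum import Enum
from typing import NamedTuple, Tuple, Union


class TrainFrequencyUnit(Enum):
    # Local stand-in for the unit enum referenced by collect_rollouts.
    STEP = "step"
    EPISODE = "episode"


class TrainFreq(NamedTuple):
    # Local stand-in: how often to train, and in which unit.
    frequency: int
    unit: TrainFrequencyUnit


def to_train_freq(train_freq: Union[int, Tuple[int, str]]) -> TrainFreq:
    """Normalize the user-facing train_freq argument (hypothetical helper)."""
    if not isinstance(train_freq, tuple):
        # A bare integer means "update every train_freq steps".
        train_freq = (train_freq, "step")
    frequency, unit = train_freq
    if not isinstance(frequency, int) or frequency <= 0:
        raise ValueError("train_freq must be a positive integer")
    # The Enum lookup raises ValueError for units other than "step"/"episode".
    return TrainFreq(frequency, TrainFrequencyUnit(unit))


print(to_train_freq(4))
print(to_train_freq((2, "episode")))
```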

collect_rollouts(env, callback, train_freq, replay_buffer, action_noise=None, learning_starts=0, log_interval=None)

Collect experiences and store them into a ReplayBuffer.

Parameters:
  • env (VecEnv) – The training environment

  • callback (BaseCallback) – Callback that will be called at each step (and at the beginning and end of the rollout)

  • train_freq (TrainFreq) – How much experience to collect by doing rollouts of current policy. Either TrainFreq(<n>, TrainFrequencyUnit.STEP) or TrainFreq(<n>, TrainFrequencyUnit.EPISODE) with <n> being an integer greater than 0.

  • action_noise (Optional[ActionNoise]) – Action noise that will be used for exploration. Required for deterministic policies (e.g. TD3). It can also be used in addition to the stochastic policy of SAC.

  • learning_starts (int) – Number of steps before learning for the warm-up phase.

  • replay_buffer (ReplayBuffer) – Buffer in which to store the collected transitions

  • log_interval (Optional[int]) – Log data every log_interval episodes

Return type:

RolloutReturn

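Returns:

A RolloutReturn holding the number of collected timesteps, the number of completed episodes, and whether training should continue.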

learn(total_timesteps, callback=None, log_interval=4, tb_log_name='run', reset_num_timesteps=True, progress_bar=False)

Return a trained model.

Parameters:
  • total_timesteps (int) – The total number of samples (env steps) to train on

  • callback (Union[None, Callable, List[BaseCallback], BaseCallback]) – callback(s) called at every step with the state of the algorithm.

  • log_interval (int) – The number of episodes before logging.

  • tb_log_name (str) – the name of the run for TensorBoard logging

  • reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)

  • progress_bar (bool) – Display a progress bar using tqdm and rich.

Return type:

TypeVar(SelfOffPolicyAlgorithm, bound=OffPolicyAlgorithm)

Returns:

the trained model

load_replay_buffer(path, truncate_last_traj=True)

Load a replay buffer from a pickle file.

Parameters:
  • path (Union[str, Path, BufferedIOBase]) – Path to the pickled replay buffer.

  • truncate_last_traj (bool) – When using HerReplayBuffer with online sampling: If set to True, we assume that the last trajectory in the replay buffer was finished (and truncate it). If set to False, we assume that we continue the same trajectory (same episode).

Return type:

None

save_replay_buffer(path)

Save the replay buffer as a pickle file.

Parameters:

path (Union[str, Path, BufferedIOBase]) – Path to the file where the replay buffer should be saved. If path is a str or pathlib.Path, the path is automatically created if necessary.

Return type:

None
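save_replay_buffer and load_replay_buffer store the buffer as a pickle file, and a str or pathlib.Path destination is created when necessary. A sketch of that round trip using only the standard library; `save_to_pickle` and `load_from_pickle` are hypothetical helpers, and the dict stands in for a real replay buffer object.

```python
import pathlib
import pickle
import tempfile
from typing import Any, Union


def save_to_pickle(path: Union[str, pathlib.Path], obj: Any) -> None:
    """Pickle obj to path, creating parent directories if necessary."""
    path = pathlib.Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as file:
        pickle.dump(obj, file)


def load_from_pickle(path: Union[str, pathlib.Path]) -> Any:
    """Load a pickled object from path."""
    with pathlib.Path(path).open("rb") as file:
        return pickle.load(file)


with tempfile.TemporaryDirectory() as tmp:
    # The "buffers" sub-directory does not exist yet; saving creates it.
    target = pathlib.Path(tmp) / "buffers" / "replay_buffer.pkl"
    save_to_pickle(target, {"observations": [[0.0, 1.0]], "pos": 1})
    restored = load_from_pickle(target)

print(restored)
```

Only unpickle files from trusted sources, since loading a pickle can execute arbitrary code.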

train(gradient_steps, batch_size)

Sample the replay buffer and do the updates (gradient descent and update target networks)

Return type:

None
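The target-network update mentioned above is the soft ("Polyak") update controlled by tau: each target parameter moves a fraction tau of the way toward the corresponding online parameter. A minimal sketch on flat lists of floats; `soft_update` is a hypothetical helper, whereas real implementations update tensors in place.

```python
from typing import List


def soft_update(params: List[float], target_params: List[float], tau: float) -> List[float]:
    """Return target_params moved towards params by a factor tau."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be between 0 and 1")
    # target <- tau * online + (1 - tau) * target
    return [tau * p + (1.0 - tau) * t for p, t in zip(params, target_params)]


online = [1.0, 2.0]
target = [0.0, 0.0]
print(soft_update(online, target, tau=0.005))
```

With tau=1.0 the target is simply overwritten (a hard copy), while small values such as the default 0.005 make the target network trail the online network slowly.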

Base On-Policy Class

The base RL class for on-policy algorithms (e.g. A2C, PPO).

class stable_baselines3.common.on_policy_algorithm.OnPolicyAlgorithm(policy, env, learning_rate, n_steps, gamma, gae_lambda, ent_coef, vf_coef, max_grad_norm, use_sde, sde_sample_freq, stats_window_size=100, tensorboard_log=None, monitor_wrapper=True, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True, supported_action_spaces=None)

The base for on-policy algorithms (e.g. A2C, PPO).

Parameters:
  • policy (Union[str, Type[ActorCriticPolicy]]) – The policy model to use (MlpPolicy, CnnPolicy, …)

  • env (Union[Env, VecEnv, str]) – The environment to learn from (if registered in Gym, can be str)

  • learning_rate (Union[float, Callable[[float], float]]) – The learning rate; it can be a function of the current progress remaining (from 1 to 0)

  • n_steps (int) – The number of steps to run for each environment per update (i.e. the batch size is n_steps * n_env, where n_env is the number of environment copies running in parallel)

  • gamma (float) – Discount factor

  • gae_lambda (float) – Factor for the trade-off between bias and variance in the Generalized Advantage Estimator. Equivalent to the classic advantage when set to 1.

  • ent_coef (float) – Entropy coefficient for the loss calculation

  • vf_coef (float) – Value function coefficient for the loss calculation

  • max_grad_norm (float) – The maximum value for the gradient clipping

  • use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)

  • sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)

  • stats_window_size (int) – Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over

  • tensorboard_log (Optional[str]) – the log location for tensorboard (if None, no logging)

  • monitor_wrapper (bool) – When creating an environment, whether to wrap it or not in a Monitor wrapper.

  • policy_kwargs (Optional[Dict[str, Any]]) – additional arguments to be passed to the policy on creation

  • verbose (int) – Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages

  • seed (Optional[int]) – Seed for the pseudo-random generators

  • device (Union[device, str]) – Device (cpu, cuda, …) on which the code should be run. When set to auto, the code runs on the GPU if possible.

  • _init_setup_model (bool) – Whether or not to build the network at the creation of the instance

  • supported_action_spaces (Optional[Tuple[Space, ...]]) – The action spaces supported by the algorithm.
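gamma and gae_lambda enter through Generalized Advantage Estimation, computed backwards over the rollout. A minimal single-environment sketch, assuming `dones[t]` marks that the episode ended at step t and `last_value` is the value estimate of the observation after the final step; `compute_gae` and its argument names are hypothetical, and the library performs this computation inside its rollout buffer for all environments at once.

```python
from typing import List, Sequence, Tuple


def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    dones: Sequence[bool],
    last_value: float,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Return (advantages, returns) for one environment's rollout."""
    n_steps = len(rewards)
    advantages = [0.0] * n_steps
    last_gae = 0.0
    for t in reversed(range(n_steps)):
        next_value = last_value if t == n_steps - 1 else values[t + 1]
        # Do not bootstrap across an episode boundary.
        next_non_terminal = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    # Returns are the value targets: advantage plus the current value estimate.
    returns = [adv + v for adv, v in zip(advantages, values)]
    return advantages, returns


advantages, returns = compute_gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.5, 0.5, 0.5],
    dones=[False, False, True],
    last_value=0.0,
    gamma=0.9,
    gae_lambda=1.0,
)
print(advantages, returns)
```

With gae_lambda=1.0 the returns reduce to the discounted Monte Carlo returns (2.71, 1.9 and 1.0 here, up to floating-point rounding), which is the sense in which GAE is equivalent to the classic advantage when set to 1; gae_lambda=0.0 would instead give one-step TD errors.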

collect_rollouts(env, callback, rollout_buffer, n_rollout_steps)

Collect experiences using the current policy and fill a RolloutBuffer. The term rollout here refers to the model-free notion and should not be confused with the concept of rollout used in model-based RL or planning.

Parameters:
  • env (VecEnv) – The training environment

  • callback (BaseCallback) – Callback that will be called at each step (and at the beginning and end of the rollout)

  • rollout_buffer (RolloutBuffer) – Buffer to fill with rollouts

  • n_rollout_steps (int) – Number of experiences to collect per environment

Return type:

bool

Returns:

True if the function returned with at least n_rollout_steps collected, False if the callback terminated the rollout prematurely.
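The boolean return value lets the training loop stop as soon as a callback asks it to. A simplified, single-environment sketch of that control flow, where `env_step`, `on_step`, and the list used as a buffer are hypothetical stand-ins for the environment, the callback, and the RolloutBuffer:

```python
from typing import Callable, List


def collect_rollouts(
    env_step: Callable[[], float],
    on_step: Callable[[], bool],
    buffer: List[float],
    n_rollout_steps: int,
) -> bool:
    """Fill buffer with n_rollout_steps transitions unless the callback stops early."""
    buffer.clear()
    for _ in range(n_rollout_steps):
        buffer.append(env_step())
        # A callback returning False aborts the rollout prematurely.
        if not on_step():
            return False
    return True


def make_stopping_callback(max_calls: int) -> Callable[[], bool]:
    """Hypothetical callback factory: asks to stop on the max_calls-th step."""
    calls = 0

    def on_step() -> bool:
        nonlocal calls
        calls += 1
        return calls < max_calls

    return on_step


buffer: List[float] = []
completed = collect_rollouts(lambda: 1.0, lambda: True, buffer, n_rollout_steps=5)
print(completed, len(buffer))
stopped = collect_rollouts(lambda: 1.0, make_stopping_callback(3), buffer, n_rollout_steps=5)
print(stopped, len(buffer))
```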

learn(total_timesteps, callback=None, log_interval=1, tb_log_name='OnPolicyAlgorithm', reset_num_timesteps=True, progress_bar=False)

Return a trained model.

Parameters:
  • total_timesteps (int) – The total number of samples (env steps) to train on

  • callback (Union[None, Callable, List[BaseCallback], BaseCallback]) – callback(s) called at every step with the state of the algorithm.

  • log_interval (int) – The number of training iterations (i.e. log_interval * n_steps * n_env timesteps) before logging.

  • tb_log_name (str) – the name of the run for TensorBoard logging

  • reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)

  • progress_bar (bool) – Display a progress bar using tqdm and rich.

Return type:

TypeVar(SelfOnPolicyAlgorithm, bound=OnPolicyAlgorithm)

Returns:

the trained model

train()

Consume current rollout data and update policy parameters. Implemented by individual algorithms.

Return type:

None