Abstract base classes for RL algorithms.

Base RL Class

Common interface for all RL algorithms.

class stable_baselines3.common.base_class.BaseAlgorithm(policy: Type[stable_baselines3.common.policies.BasePolicy], env: Optional[Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv, str]], policy_base: Type[stable_baselines3.common.policies.BasePolicy], learning_rate: Union[float, Callable], policy_kwargs: Dict[str, Any] = None, tensorboard_log: Optional[str] = None, verbose: int = 0, device: Union[torch.device, str] = 'auto', support_multi_env: bool = False, create_eval_env: bool = False, monitor_wrapper: bool = True, seed: Optional[int] = None, use_sde: bool = False, sde_sample_freq: int = -1)

The base class for RL algorithms.

Parameters
  • policy – (Type[BasePolicy]) Policy object

  • env – (Union[GymEnv, str, None]) The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models)

  • policy_base – (Type[BasePolicy]) The base policy used by this method

  • learning_rate – (float or callable) Learning rate for the optimizer; it can also be a function of the current progress remaining (from 1 to 0)

  • policy_kwargs – (Dict[str, Any]) Additional arguments to be passed to the policy on creation

  • tensorboard_log – (str) the log location for tensorboard (if None, no logging)

  • verbose – (int) The verbosity level: 0 none, 1 training information, 2 debug

  • device – (Union[th.device, str]) Device on which the code should run. By default, it will try to use a CUDA-compatible device and fall back to CPU if that is not possible.

  • support_multi_env – (bool) Whether the algorithm supports training with multiple environments (as in A2C)

  • create_eval_env – (bool) Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment)

  • monitor_wrapper – (bool) When creating an environment, whether or not to wrap it in a Monitor wrapper.

  • seed – (Optional[int]) Seed for the pseudo random generators

  • use_sde – (bool) Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)

  • sde_sample_freq – (int) Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)
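Because learning_rate may be a callable of the progress remaining (1 at the start of training, 0 at the end), a schedule is simply a function of one float. A minimal sketch of a linearly decaying schedule under that convention; linear_schedule is a hypothetical helper, not part of the documented API:

```python
from typing import Callable


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Hypothetical helper: a schedule that decays linearly from initial_value to 0."""

    def schedule(progress_remaining: float) -> float:
        # progress_remaining is 1.0 at the start of training and 0.0 at the end
        return progress_remaining * initial_value

    return schedule


lr = linear_schedule(3e-4)
print(lr(1.0), lr(0.5), lr(0.0))
```

Under this assumption, such a callable would be passed as learning_rate=linear_schedule(3e-4) when constructing a concrete algorithm.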

excluded_save_params() → List[str]

Returns the names of the parameters that should be excluded by default when saving the model.

Returns

([str]) List of parameters that should be excluded from save

get_env() → Optional[stable_baselines3.common.vec_env.base_vec_env.VecEnv]

Returns the current environment (can be None if not defined).

Returns

(Optional[VecEnv]) The current environment

get_torch_variables() → Tuple[List[str], List[str]]

Get the names of the torch variables that will be saved. These are saved with th.save and loaded with th.load (on the right device) instead of using the default pickling strategy.

Returns

(Tuple[List[str], List[str]]) names of the variables whose state dicts should be saved, and names of additional torch tensors to save

get_vec_normalize_env() → Optional[stable_baselines3.common.vec_env.vec_normalize.VecNormalize]

Return the VecNormalize wrapper of the training env if it exists.

Returns

(Optional[VecNormalize]) The VecNormalize env.

abstract learn(total_timesteps: int, callback: Union[None, Callable, List[stable_baselines3.common.callbacks.BaseCallback], stable_baselines3.common.callbacks.BaseCallback] = None, log_interval: int = 100, tb_log_name: str = 'run', eval_env: Optional[Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv]] = None, eval_freq: int = -1, n_eval_episodes: int = 5, eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True) → stable_baselines3.common.base_class.BaseAlgorithm

Return a trained model.

Parameters
  • total_timesteps – (int) The total number of samples (env steps) to train on

  • callback – (MaybeCallback) callback(s) called at every step with the state of the algorithm.

  • log_interval – (int) The number of timesteps before logging.

  • tb_log_name – (str) the name of the run for TensorBoard logging

  • eval_env – (gym.Env) Environment that will be used to evaluate the agent

  • eval_freq – (int) Evaluate the agent every eval_freq timesteps (this may vary a little)

  • n_eval_episodes – (int) Number of episodes used to evaluate the agent

  • eval_log_path – (Optional[str]) Path to a folder where the evaluations will be saved

  • reset_num_timesteps – (bool) whether or not to reset the current timestep number (used in logging)

Returns

(BaseAlgorithm) the trained model
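The reset_num_timesteps flag decides whether a new call to learn() restarts the timestep counter used in logging or continues it. The sketch below models only that bookkeeping with a hypothetical TimestepCounter class; it assumes that, when the counter is not reset, total_timesteps is counted on top of the steps already done, so a second call trains for the requested number of additional steps.

```python
class TimestepCounter:
    """Hypothetical stand-in for the timestep bookkeeping done by learn()."""

    def __init__(self) -> None:
        self.num_timesteps = 0

    def learn(self, total_timesteps: int, reset_num_timesteps: bool = True) -> "TimestepCounter":
        if reset_num_timesteps:
            # Start counting (and logging) from zero again
            self.num_timesteps = 0
        else:
            # Assumption: shift the target so training continues for total_timesteps more steps
            total_timesteps += self.num_timesteps
        while self.num_timesteps < total_timesteps:
            # One environment step would happen here
            self.num_timesteps += 1
        return self  # like learn(), return the (trained) model itself


model = TimestepCounter().learn(1000)
model.learn(500, reset_num_timesteps=False)
print(model.num_timesteps)  # 1500
```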

classmethod load(load_path: str, env: Optional[Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv]] = None, **kwargs) → stable_baselines3.common.base_class.BaseAlgorithm

Load the model from a zip-file

Parameters
  • load_path – the location of the saved data

  • env – the new environment to run the loaded model on (can be None if you only need prediction from a trained model); it has priority over any saved environment

  • kwargs – extra arguments to change the model when loading

predict(observation: numpy.ndarray, state: Optional[numpy.ndarray] = None, mask: Optional[numpy.ndarray] = None, deterministic: bool = False) → Tuple[numpy.ndarray, Optional[numpy.ndarray]]

Get the model’s action(s) from an observation

Parameters
  • observation – (np.ndarray) the input observation

  • state – (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)

  • mask – (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies)

  • deterministic – (bool) Whether or not to return deterministic actions.

Returns

(Tuple[np.ndarray, Optional[np.ndarray]]) the model’s action and the next state (used in recurrent policies)
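With deterministic=True the model returns its most likely action instead of sampling one. For a discrete action distribution this roughly amounts to taking the argmax of the action probabilities. The sketch below illustrates the difference in plain Python, using a hypothetical select_action helper and a fixed probability vector in place of a real policy:

```python
import random
from typing import List, Optional


def select_action(probs: List[float], deterministic: bool, rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper: pick an action index from a categorical distribution."""
    if deterministic:
        # Mode of the distribution: the action with the highest probability
        return max(range(len(probs)), key=lambda i: probs[i])
    rng = rng or random.Random()
    # Stochastic: sample an action index proportionally to its probability
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = [0.1, 0.7, 0.2]
print(select_action(probs, deterministic=True))  # 1
print(select_action(probs, deterministic=False, rng=random.Random(0)))
```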

save(path: Union[str, pathlib.Path, io.BufferedIOBase], exclude: Optional[Iterable[str]] = None, include: Optional[Iterable[str]] = None) → None

Save all the attributes of the object and the model parameters in a zip-file.

Parameters
  • path – (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file where the RL agent should be saved

  • exclude – names of parameters that should be excluded in addition to the default ones

  • include – names of parameters that might be excluded but should be included anyway
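When saving, the attributes listed by excluded_save_params() are skipped, exclude adds more names to skip, and include forces some of them back in. A minimal sketch of how those three sets could combine; resolve_excluded and the attribute names used below are hypothetical:

```python
from typing import Iterable, List, Optional, Set


def resolve_excluded(
    default_excluded: List[str],
    exclude: Optional[Iterable[str]] = None,
    include: Optional[Iterable[str]] = None,
) -> Set[str]:
    """Hypothetical helper: compute the attribute names that would not be saved."""
    # Start from the defaults (cf. excluded_save_params()) plus any user-given names
    excluded = set(default_excluded).union(exclude or [])
    if include is not None:
        # Names in include are saved even if they are excluded by default
        excluded = excluded.difference(include)
    return excluded


print(sorted(resolve_excluded(["env", "replay_buffer"], exclude=["tensorboard_log"], include=["env"])))
# ['replay_buffer', 'tensorboard_log']
```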

set_env(env: Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv]) → None

Checks the validity of the environment and, if it is coherent, sets it as the current environment. Any non-vectorized env is wrapped into a vectorized one. The checked parameters are observation_space and action_space.

Parameters

env – The environment for learning a policy
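set_env() compares the observation_space and action_space of the new environment with those of the model. The sketch below illustrates such a check with a hypothetical SpaceSpec dataclass standing in for Gym spaces (real spaces are objects compared with ==); raising ValueError on a mismatch is a choice made for this sketch.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SpaceSpec:
    """Hypothetical stand-in for a Gym space: a kind and a shape."""

    kind: str
    shape: Tuple[int, ...]


def check_env_compatible(model_obs: SpaceSpec, model_act: SpaceSpec, env_obs: SpaceSpec, env_act: SpaceSpec) -> None:
    """Raise ValueError if the new env's spaces differ from the model's."""
    if env_obs != model_obs:
        raise ValueError(f"Observation spaces do not match: {model_obs} != {env_obs}")
    if env_act != model_act:
        raise ValueError(f"Action spaces do not match: {model_act} != {env_act}")


obs, act = SpaceSpec("Box", (4,)), SpaceSpec("Box", (1,))
check_env_compatible(obs, act, SpaceSpec("Box", (4,)), SpaceSpec("Box", (1,)))  # passes silently
```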

set_random_seed(seed: Optional[int] = None) → None

Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)

Parameters

seed – (Optional[int]) Seed for the pseudo-random generators
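Seeding the pseudo-random generators makes runs reproducible. The sketch below is deliberately reduced: the hypothetical seed_everything only seeds Python's random module, whereas set_random_seed() also covers numpy, pytorch, gym and the action space. Treating seed=None as "do nothing" is an assumption of this sketch.

```python
import random
from typing import List, Optional


def seed_everything(seed: Optional[int] = None) -> None:
    """Hypothetical, reduced version: only seeds Python's built-in generator."""
    if seed is None:
        return  # assumption: no seed means leave the generators as they are
    random.seed(seed)
    # A full version would also seed numpy, pytorch, the env and the action space


def sample_noise(n: int) -> List[float]:
    return [random.random() for _ in range(n)]


seed_everything(42)
first = sample_noise(3)
seed_everything(42)
second = sample_noise(3)
print(first == second)  # True
```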

Base Off-Policy Class

The base RL algorithm for off-policy algorithms (ex: SAC/TD3)

class stable_baselines3.common.off_policy_algorithm.OffPolicyAlgorithm(policy: Type[stable_baselines3.common.policies.BasePolicy], env: Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv, str], policy_base: Type[stable_baselines3.common.policies.BasePolicy], learning_rate: Union[float, Callable], buffer_size: int = 1000000, learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, train_freq: int = 1, gradient_steps: int = 1, n_episodes_rollout: int = -1, action_noise: Optional[stable_baselines3.common.noise.ActionNoise] = None, optimize_memory_usage: bool = False, policy_kwargs: Dict[str, Any] = None, tensorboard_log: Optional[str] = None, verbose: int = 0, device: Union[torch.device, str] = 'auto', support_multi_env: bool = False, create_eval_env: bool = False, monitor_wrapper: bool = True, seed: Optional[int] = None, use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, sde_support: bool = True)

The base for Off-Policy algorithms (ex: SAC/TD3)

Parameters
  • policy – Policy object

  • env – The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models)

  • policy_base – The base policy used by this method

  • learning_rate – (float or callable) Learning rate for the optimizer; it can also be a function of the current progress remaining (from 1 to 0)

  • buffer_size – (int) size of the replay buffer

  • learning_starts – (int) how many steps of the model to collect transitions for before learning starts

  • batch_size – (int) Minibatch size for each gradient update

  • tau – (float) the soft update coefficient (“Polyak update”, between 0 and 1)

  • gamma – (float) the discount factor

  • train_freq – (int) Update the model every train_freq steps. Set to -1 to disable.

  • gradient_steps – (int) How many gradient steps to do after each rollout (see train_freq and n_episodes_rollout). Set to -1 to do as many gradient steps as steps done in the environment during the rollout.

  • n_episodes_rollout – (int) Update the model every n_episodes_rollout episodes. Note that this cannot be used at the same time as train_freq. Set to -1 to disable.

  • action_noise – (ActionNoise) The action noise type (None by default); this can help with hard exploration problems. See common.noise for the different action noise types.

  • optimize_memory_usage – (bool) Enable a memory-efficient variant of the replay buffer at the cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195

  • policy_kwargs – Additional arguments to be passed to the policy on creation

  • tensorboard_log – (str) the log location for tensorboard (if None, no logging)

  • verbose – The verbosity level: 0 none, 1 training information, 2 debug

  • device – Device on which the code should run. By default, it will try to use a CUDA-compatible device and fall back to CPU if that is not possible.

  • support_multi_env – Whether the algorithm supports training with multiple environments (as in A2C)

  • create_eval_env – Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment)

  • monitor_wrapper – When creating an environment, whether or not to wrap it in a Monitor wrapper.

  • seed – Seed for the pseudo random generators

  • use_sde – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)

  • sde_sample_freq – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)

  • use_sde_at_warmup – (bool) Whether to use gSDE instead of uniform sampling during the warm-up phase (before learning starts)

  • sde_support – (bool) Whether the model supports gSDE or not
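The tau parameter controls the soft ("Polyak") update of the target network: each target parameter moves a fraction tau towards the current parameter, target ← tau · param + (1 - tau) · target. A minimal sketch on plain lists of floats standing in for network weights (an assumption made to keep the example dependency-free); polyak_update is a hypothetical name here:

```python
from typing import List


def polyak_update(params: List[float], target_params: List[float], tau: float) -> None:
    """Soft-update target_params in place towards params."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be between 0 and 1")
    for i, (p, t) in enumerate(zip(params, target_params)):
        # tau = 1 copies the weights (hard update); tau = 0 leaves the target unchanged
        target_params[i] = tau * p + (1.0 - tau) * t


target = [0.0, 0.0]
polyak_update([1.0, 2.0], target, tau=0.005)
print(target)
```

A small tau such as the default 0.005 makes the target network change slowly, which keeps the bootstrapped targets used in train() more stable.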

collect_rollouts(env: stable_baselines3.common.vec_env.base_vec_env.VecEnv, callback: stable_baselines3.common.callbacks.BaseCallback, n_episodes: int = 1, n_steps: int = -1, action_noise: Optional[stable_baselines3.common.noise.ActionNoise] = None, learning_starts: int = 0, replay_buffer: Optional[stable_baselines3.common.buffers.ReplayBuffer] = None, log_interval: Optional[int] = None) → stable_baselines3.common.type_aliases.RolloutReturn

Collect experiences and store them into a ReplayBuffer.

Parameters
  • env – (VecEnv) The training environment

  • callback – (BaseCallback) Callback that will be called at each step (and at the beginning and end of the rollout)

  • n_episodes – (int) Number of episodes to use to collect rollout data. You can also specify n_steps instead.

  • n_steps – (int) Number of steps to use to collect rollout data. You can also specify n_episodes instead.

  • action_noise – (Optional[ActionNoise]) Action noise that will be used for exploration. Required for deterministic policies (e.g. TD3). This can also be used in addition to the stochastic policy for SAC.

  • learning_starts – (int) Number of steps before learning for the warm-up phase.

  • replay_buffer – (ReplayBuffer) The replay buffer in which the collected transitions are stored

  • log_interval – (int) Log data every log_interval episodes

Returns

(RolloutReturn)
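During the warm-up phase (fewer than learning_starts steps collected), actions do not come from the policy: unless use_sde_at_warmup is set, they are sampled uniformly from the action space. Afterwards the policy action is used, with action_noise added when given. The sketch below models that decision for a one-dimensional continuous action in [low, high]; the function name, the Gaussian noise and the final clipping are illustrative assumptions.

```python
import random
from typing import Callable, Optional


def sample_action(
    num_timesteps: int,
    learning_starts: int,
    policy: Callable[[], float],
    low: float,
    high: float,
    noise_std: Optional[float] = None,
    use_sde_at_warmup: bool = False,
    rng: Optional[random.Random] = None,
) -> float:
    """Hypothetical action selection for a 1-D continuous action in [low, high]."""
    rng = rng or random.Random()
    if num_timesteps < learning_starts and not use_sde_at_warmup:
        # Warm-up: uniform sampling, the policy is not queried
        return rng.uniform(low, high)
    action = policy()
    if noise_std is not None:
        # Exploration noise (Gaussian here, as one example of action noise)
        action += rng.gauss(0.0, noise_std)
    # Keep the action inside the bounds of the action space
    return min(max(action, low), high)
```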

learn(total_timesteps: int, callback: Union[None, Callable, List[stable_baselines3.common.callbacks.BaseCallback], stable_baselines3.common.callbacks.BaseCallback] = None, log_interval: int = 4, eval_env: Optional[Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv]] = None, eval_freq: int = -1, n_eval_episodes: int = 5, tb_log_name: str = 'run', eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True) → stable_baselines3.common.off_policy_algorithm.OffPolicyAlgorithm

Return a trained model.

Parameters
  • total_timesteps – (int) The total number of samples (env steps) to train on

  • callback – (MaybeCallback) callback(s) called at every step with the state of the algorithm.

  • log_interval – (int) The number of episodes before logging.

  • tb_log_name – (str) the name of the run for TensorBoard logging

  • eval_env – (gym.Env) Environment that will be used to evaluate the agent

  • eval_freq – (int) Evaluate the agent every eval_freq timesteps (this may vary a little)

  • n_eval_episodes – (int) Number of episodes used to evaluate the agent

  • eval_log_path – (Optional[str]) Path to a folder where the evaluations will be saved

  • reset_num_timesteps – (bool) whether or not to reset the current timestep number (used in logging)

Returns

(BaseAlgorithm) the trained model
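After each rollout, training only starts once more than learning_starts steps have been collected, and gradient_steps=-1 means "as many gradient steps as environment steps done during the rollout" (see the gradient_steps parameter of the class). A small sketch of that rule; the function name is hypothetical and the exact warm-up condition is an assumption:

```python
def n_gradient_steps(gradient_steps: int, rollout_env_steps: int, num_timesteps: int, learning_starts: int) -> int:
    """Hypothetical: number of gradient steps to run after a rollout."""
    if num_timesteps <= learning_starts:
        # Still in the warm-up phase: no training yet
        return 0
    if gradient_steps > 0:
        return gradient_steps
    # gradient_steps <= 0 (e.g. -1): match the number of env steps done in the rollout
    return rollout_env_steps


print(n_gradient_steps(-1, rollout_env_steps=8, num_timesteps=1000, learning_starts=100))  # 8
```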

load_replay_buffer(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None

Load a replay buffer from a pickle file.

Parameters

path – (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer.

save_replay_buffer(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None

Save the replay buffer as a pickle file.

Parameters

path – (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the file where the replay buffer should be saved. If path is a str or pathlib.Path, the path is automatically created if necessary.
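Both replay buffer methods rely on pickle, and a str or pathlib.Path destination has its missing parent directories created. A minimal standard-library sketch of that behaviour; save_pickle and load_pickle are hypothetical names, and any picklable object can stand in for the ReplayBuffer:

```python
import io
import pathlib
import pickle
from typing import Any, Union


def save_pickle(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any) -> None:
    """Pickle obj to path, creating parent directories for str/Path destinations."""
    if isinstance(path, (str, pathlib.Path)):
        path = pathlib.Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("wb") as f:
            pickle.dump(obj, f)
    else:
        # Already an open binary file-like object
        pickle.dump(obj, path)


def load_pickle(path: Union[str, pathlib.Path, io.BufferedIOBase]) -> Any:
    """Unpickle an object from a path or an open binary file-like object."""
    if isinstance(path, (str, pathlib.Path)):
        with pathlib.Path(path).open("rb") as f:
            return pickle.load(f)
    return pickle.load(path)
```

Since pickle can execute arbitrary code on load, such files should only be loaded from trusted sources.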

train(gradient_steps: int, batch_size: int) → None

Sample the replay buffer and do the updates (gradient descent and update target networks)

Base On-Policy Class

The base RL algorithm for on-policy algorithms (ex: A2C/PPO)

class stable_baselines3.common.on_policy_algorithm.OnPolicyAlgorithm(policy: Union[str, Type[stable_baselines3.common.policies.ActorCriticPolicy]], env: Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv, str], learning_rate: Union[float, Callable], n_steps: int, gamma: float, gae_lambda: float, ent_coef: float, vf_coef: float, max_grad_norm: float, use_sde: bool, sde_sample_freq: int, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, monitor_wrapper: bool = True, policy_kwargs: Optional[Dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[torch.device, str] = 'auto', _init_setup_model: bool = True)

The base for On-Policy algorithms (ex: A2C/PPO).

Parameters
  • policy – (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, …)

  • env – (Gym environment or str) The environment to learn from (if registered in Gym, can be str)

  • learning_rate – (float or callable) The learning rate; it can be a function of the current progress remaining (from 1 to 0)

  • n_steps – (int) The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)

  • gamma – (float) Discount factor

  • gae_lambda – (float) Factor for the trade-off between bias and variance in the Generalized Advantage Estimator. Equivalent to the classic advantage when set to 1.

  • ent_coef – (float) Entropy coefficient for the loss calculation

  • vf_coef – (float) Value function coefficient for the loss calculation

  • max_grad_norm – (float) The maximum value for the gradient clipping

  • use_sde – (bool) Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)

  • sde_sample_freq – (int) Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)

  • tensorboard_log – (str) the log location for tensorboard (if None, no logging)

  • create_eval_env – (bool) Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment)

  • monitor_wrapper – When creating an environment, whether or not to wrap it in a Monitor wrapper.

  • policy_kwargs – (dict) additional arguments to be passed to the policy on creation

  • verbose – (int) the verbosity level: 0 no output, 1 info, 2 debug

  • seed – (int) Seed for the pseudo random generators

  • device – (str or th.device) Device (cpu, cuda, …) on which the code should be run. If set to auto, the code will run on the GPU if possible.

  • _init_setup_model – (bool) Whether or not to build the network at the creation of the instance
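As a sketch of what gamma and gae_lambda control, the function below computes Generalized Advantage Estimation for a single environment: δ_t = r_t + γ·V(s_{t+1})·(1 - done_t) - V(s_t) and A_t = δ_t + γ·λ·(1 - done_t)·A_{t+1}, where done_t marks an episode that ended after step t (this indexing convention is an assumption; the library's buffer layout may differ). With λ = 1 the advantage reduces to the bootstrapped discounted return minus the value estimate, and with λ = 0 it is the one-step TD error. Since a rollout holds n_steps transitions per environment, n_steps=5 with 8 parallel environments, for example, gives 40 transitions per update.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    last_value: float,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """GAE for one environment; dones[t] is True if the episode ended after step t."""
    n = len(rewards)
    advantages = [0.0] * n
    last_gae = 0.0
    for t in reversed(range(n)):
        # Value of the next state; last_value bootstraps beyond the end of the rollout
        next_value = last_value if t == n - 1 else values[t + 1]
        next_non_terminal = 0.0 if dones[t] else 1.0
        # One-step TD error
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        # Exponentially weighted sum of future TD errors, cut at episode boundaries
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    # Returns used as value-function targets
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```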

collect_rollouts(env: stable_baselines3.common.vec_env.base_vec_env.VecEnv, callback: stable_baselines3.common.callbacks.BaseCallback, rollout_buffer: stable_baselines3.common.buffers.RolloutBuffer, n_rollout_steps: int) → bool

Collect rollouts using the current policy and fill a RolloutBuffer.

Parameters
  • env – (VecEnv) The training environment

  • callback – (BaseCallback) Callback that will be called at each step (and at the beginning and end of the rollout)

  • rollout_buffer – (RolloutBuffer) Buffer to fill with rollouts

  • n_rollout_steps – (int) Number of experiences to collect per environment

Returns

(bool) True if function returned with at least n_rollout_steps collected, False if callback terminated rollout prematurely.
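The boolean return value lets the caller stop training as soon as a callback asks for it. A minimal sketch of that control flow, with hypothetical step_env and on_step callables (on_step returns False to abort) and a plain list in place of the RolloutBuffer:

```python
from typing import Callable, List


def collect_rollouts(
    step_env: Callable[[], float],
    on_step: Callable[[], bool],
    rollout_buffer: List[float],
    n_rollout_steps: int,
) -> bool:
    """Hypothetical sketch: fill rollout_buffer with n_rollout_steps transitions."""
    rollout_buffer.clear()  # on-policy: data from the previous policy is discarded
    for _ in range(n_rollout_steps):
        transition = step_env()
        if not on_step():
            # The callback asked to stop: return early with an incomplete rollout
            return False
        rollout_buffer.append(transition)
    return True
```

In a training loop, a False result would end learning early instead of calling train() on the partial rollout.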

get_torch_variables() → Tuple[List[str], List[str]]

See the base class.

learn(total_timesteps: int, callback: Union[None, Callable, List[stable_baselines3.common.callbacks.BaseCallback], stable_baselines3.common.callbacks.BaseCallback] = None, log_interval: int = 1, eval_env: Optional[Union[gym.core.Env, stable_baselines3.common.vec_env.base_vec_env.VecEnv]] = None, eval_freq: int = -1, n_eval_episodes: int = 5, tb_log_name: str = 'OnPolicyAlgorithm', eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True) → stable_baselines3.common.on_policy_algorithm.OnPolicyAlgorithm

Return a trained model.

Parameters
  • total_timesteps – (int) The total number of samples (env steps) to train on

  • callback – (MaybeCallback) callback(s) called at every step with the state of the algorithm.

  • log_interval – (int) The number of training iterations (rollout collection followed by a policy update) before logging.

  • tb_log_name – (str) the name of the run for TensorBoard logging

  • eval_env – (gym.Env) Environment that will be used to evaluate the agent

  • eval_freq – (int) Evaluate the agent every eval_freq timesteps (this may vary a little)

  • n_eval_episodes – (int) Number of episodes used to evaluate the agent

  • eval_log_path – (Optional[str]) Path to a folder where the evaluations will be saved

  • reset_num_timesteps – (bool) whether or not to reset the current timestep number (used in logging)

Returns

(BaseAlgorithm) the trained model

train() → None

Consume current rollout data and update policy parameters. Implemented by individual algorithms.