Callbacks

A callback is a set of functions that will be called at given stages of the training procedure. You can use callbacks to access the internal state of the RL model during training, which allows you to do monitoring, auto saving, model manipulation, progress bars, …

Custom Callback

To build a custom callback, you need to create a class that derives from BaseCallback. This will give you access to events (_on_training_start, _on_step) and useful variables (like self.model for the RL model).

You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see Examples), and one for logging additional values with Tensorboard (see Tensorboard section).

from stable_baselines3.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    """
    A custom callback that derives from ``BaseCallback``.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    def __init__(self, verbose=0):
        super(CustomCallback, self).__init__(verbose)
        # Those variables will be accessible in the callback
        # (they are defined in the base class)
        # The RL model
        # self.model = None  # type: BaseAlgorithm
        # An alias for self.model.get_env(), the environment used for training
        # self.training_env = None  # type: Union[gym.Env, VecEnv, None]
        # Number of times the callback was called
        # self.n_calls = 0  # type: int
        # num_timesteps = n_envs * n times env.step() was called
        # self.num_timesteps = 0  # type: int
        # local and global variables
        # self.locals = None  # type: Dict[str, Any]
        # self.globals = None  # type: Dict[str, Any]
        # The logger object, used to report things in the terminal
        # self.logger = None  # stable_baselines3.common.logger
        # Sometimes, for event callback, it is useful
        # to have access to the parent object
        # self.parent = None  # type: Optional[BaseCallback]

    def _on_training_start(self) -> None:
        """
        This method is called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        A rollout is the collection of environment interaction
        using the current policy.
        This event is triggered before collecting new samples.
        """
        pass

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        For child callback (of an `EventCallback`), this will be called
        when the event is triggered.

        :return: (bool) If the callback returns False, training is aborted early.
        """
        return True

    def _on_rollout_end(self) -> None:
        """
        This event is triggered before updating the policy.
        """
        pass

    def _on_training_end(self) -> None:
        """
        This event is triggered before exiting the `learn()` method.
        """
        pass

Note

self.num_timesteps corresponds to the total number of steps taken in the environment, i.e., it is the number of environments multiplied by the number of times env.step() was called; it is therefore incremented by n_envs (the number of environments) after each call to env.step().

Note

For off-policy algorithms like SAC, DDPG, TD3 or DQN, the notion of rollout corresponds to the steps taken in the environment between two updates.
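As a minimal, concrete illustration of the skeleton above (the callback name and the step budget are made up for this example), the following sketch logs progress from _on_step() and stops training early by returning False once a fixed number of environment steps has been taken:

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback


class SimpleStopCallback(BaseCallback):
    """
    Hypothetical callback: log progress and stop training
    after ``max_steps`` environment steps.
    """
    def __init__(self, max_steps: int = 10_000, verbose: int = 0):
        super().__init__(verbose)
        self.max_steps = max_steps

    def _on_step(self) -> bool:
        # self.num_timesteps is incremented by n_envs after each call to env.step()
        if self.verbose >= 1 and self.n_calls % 1000 == 0:
            print(f"{self.num_timesteps} steps taken so far")
        # Returning False aborts training early
        return self.num_timesteps < self.max_steps


model = PPO("MlpPolicy", "Pendulum-v1")
model.learn(100_000, callback=SimpleStopCallback(max_steps=10_000, verbose=1))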

Event Callback

Compared to Keras, Stable Baselines provides a second type of BaseCallback, named EventCallback, that is meant to trigger events. When an event is triggered, a child callback is called.

As an example, EvalCallback is an EventCallback that will trigger its child callback when there is a new best model. An example of a child callback is StopTrainingOnRewardThreshold, which stops the training if the mean reward achieved by the RL model is above a threshold.

Note

We recommend taking a look at the source code of EvalCallback and StopTrainingOnRewardThreshold to get a better overview of what can be achieved with this kind of callback.

class EventCallback(BaseCallback):
    """
    Base class for triggering callback on event.

    :param callback: (Optional[BaseCallback]) Callback that will be called
        when an event is triggered.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
        super(EventCallback, self).__init__(verbose=verbose)
        self.callback = callback
        # Give access to the parent
        if callback is not None:
            self.callback.parent = self
    ...

    def _on_event(self) -> bool:
        if self.callback is not None:
            return self.callback()
        return True
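
For illustration, here is a sketch of a custom child callback used with EvalCallback (LogBestCallback is a made-up name for this example): its _on_step() is only called when the parent triggers the event, i.e. each time a new best model is found, and it can reach the parent through self.parent.

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback


class LogBestCallback(BaseCallback):
    """
    Hypothetical child callback: called only when the parent
    ``EvalCallback`` finds a new best model.
    """
    def _on_step(self) -> bool:
        # self.parent is the EvalCallback that triggered the event
        print(f"New best mean reward: {self.parent.best_mean_reward:.2f}")
        return True


eval_env = gym.make("Pendulum-v1")
eval_callback = EvalCallback(eval_env, callback_on_new_best=LogBestCallback(), eval_freq=500)

model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(5000, callback=eval_callback)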

Callback Collection

Stable Baselines provides you with a set of common callbacks for:

  • saving the model periodically (CheckpointCallback)

  • evaluating the model periodically and saving the best one (EvalCallback)

  • chaining callbacks (CallbackList)

  • triggering callbacks on events (EveryNTimesteps)

  • displaying a progress bar during training (ProgressBarCallback)

  • stopping the training early based on a reward threshold, a maximum number of episodes, or a lack of improvement (StopTrainingOnRewardThreshold, StopTrainingOnMaxEpisodes, StopTrainingOnNoModelImprovement)

CheckpointCallback

Callback for saving a model every save_freq calls to env.step(). You must specify a log folder (save_path) and, optionally, a prefix for the checkpoints (rl_model by default). If you are using this callback to stop and resume training, you may also want to save the replay buffer if the model has one (save_replay_buffer, False by default). Additionally, if your environment uses a VecNormalize wrapper, you can save the corresponding statistics using save_vecnormalize (False by default).

Warning

When using multiple environments, each call to env.step() will effectively correspond to n_envs steps. If you want save_freq to remain comparable across different numbers of environments, you need to account for it using save_freq = max(save_freq // n_envs, 1). The same goes for the other callbacks.

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback

# Save a checkpoint every 1000 steps
checkpoint_callback = CheckpointCallback(
  save_freq=1000,
  save_path="./logs/",
  name_prefix="rl_model",
  save_replay_buffer=True,
  save_vecnormalize=True,
)

model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(2000, callback=checkpoint_callback)
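
To apply the adjustment described in the warning above when training with multiple environments, a minimal sketch (assuming make_vec_env from stable_baselines3.common.env_util; the frequency and number of environments are illustrative):

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env

n_envs = 4
# Each callback call corresponds to n_envs environment steps,
# so this saves a checkpoint roughly every 1000 environment steps in total
checkpoint_callback = CheckpointCallback(
    save_freq=max(1000 // n_envs, 1),
    save_path="./logs/",
)

vec_env = make_vec_env("Pendulum-v1", n_envs=n_envs)
model = PPO("MlpPolicy", vec_env)
model.learn(10_000, callback=checkpoint_callback)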

EvalCallback

Periodically evaluate the performance of an agent, using a separate test environment. It will save the best model if the best_model_save_path folder is specified and save the evaluation results in a NumPy archive (evaluations.npz) if the log_path folder is specified.

Note

You can pass child callbacks via the callback_after_eval and callback_on_new_best arguments. callback_after_eval will be triggered after every evaluation, and callback_on_new_best will be triggered each time there is a new best model.

Warning

You need to make sure that eval_env is wrapped the same way as the training environment, for instance using the VecTransposeImage wrapper if you have a channel-last image as input. The EvalCallback class outputs a warning if it is not the case.

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback

# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Use deterministic actions for evaluation
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/",
                             log_path="./logs/", eval_freq=500,
                             deterministic=True, render=False)

model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(5000, callback=eval_callback)

ProgressBarCallback

Display a progress bar with the current progress, elapsed time and estimated remaining time. This callback is integrated inside SB3 via the progress_bar argument of the learn() method.

Note

This callback requires the tqdm and rich packages to be installed. This is done automatically when using pip install stable-baselines3[extra].

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import ProgressBarCallback

model = PPO("MlpPolicy", "Pendulum-v1")
# Display progress bar using the progress bar callback
# this is equivalent to model.learn(100_000, callback=ProgressBarCallback())
model.learn(100_000, progress_bar=True)

CallbackList

Class for chaining callbacks; they will be called sequentially. Alternatively, you can pass a list of callbacks directly to the learn() method; it will be converted automatically to a CallbackList.

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback

checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/")
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/best_model",
                             log_path="./logs/results", eval_freq=500)
# Create the callback list
callback = CallbackList([checkpoint_callback, eval_callback])

model = SAC("MlpPolicy", "Pendulum-v1")
# Equivalent to:
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
model.learn(5000, callback=callback)

StopTrainingOnRewardThreshold

Stop the training once a threshold in episodic reward (mean episode reward over the evaluations) has been reached (i.e., when the model is good enough). It must be used with the EvalCallback and relies on the event triggered by a new best model.

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Stop training when the model reaches the reward threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)

model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the reward threshold is reached
model.learn(int(1e10), callback=eval_callback)

EveryNTimesteps

An Event Callback that will trigger its child callback every n_steps timesteps.

Note

When using multiple environments, num_timesteps increases by n_envs at every step, so n_steps is only a lower bound on the number of timesteps between two events.

import gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps

# this is equivalent to defining CheckpointCallback(save_freq=500)
# checkpoint_callback will be triggered every 500 steps
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/")
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)

model.learn(int(2e4), callback=event_callback)

StopTrainingOnMaxEpisodes

Stop the training upon reaching the maximum number of episodes, regardless of the model’s total_timesteps value. For multiple environments, it presumes that the desired behavior is for the agent to train on each env for max_episodes, i.e., for max_episodes * n_envs episodes in total.

Note

For multiple environments, the agent will train for a total of max_episodes * n_envs episodes. However, it can’t be guaranteed that this training will occur for an exact number of max_episodes per environment. Thus, there is an assumption that, on average, each environment ran for max_episodes.

from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes

# Stops training when the model reaches the maximum number of episodes
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)

model = A2C("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the max number of episodes is reached
model.learn(int(1e10), callback=callback_max_episodes)
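
For the multiple-environment case described in the note above, a minimal sketch (assuming make_vec_env from stable_baselines3.common.env_util; the numbers are illustrative):

from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes
from stable_baselines3.common.env_util import make_vec_env

# With 4 environments, training stops after about 5 * 4 = 20 episodes in total
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)

vec_env = make_vec_env("Pendulum-v1", n_envs=4)
model = A2C("MlpPolicy", vec_env, verbose=1)
model.learn(int(1e10), callback=callback_max_episodes)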

StopTrainingOnNoModelImprovement

Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations. The idea is to save time in experiments when you know that the learning curves are reasonably well behaved and, therefore, that after many evaluations without improvement the learning has probably stabilized. It must be used with the EvalCallback and relies on the event triggered after every evaluation.

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement

# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Stop training if there is no improvement after more than 3 evaluations
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1)
eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1)

model = SAC("MlpPolicy", "Pendulum-v1", learning_rate=1e-3, verbose=1)
# Almost infinite number of timesteps, but the training will stop early
# as soon as the number of consecutive evaluations without model
# improvement is greater than 3
model.learn(int(1e10), callback=eval_callback)

class stable_baselines3.common.callbacks.BaseCallback(verbose=0)

Base class for callback.

Parameters:

verbose (int) – Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages

init_callback(model)

Initialize the callback by saving references to the RL model and the training environment for convenience.

Return type:

None

on_step()

This method will be called by the model after each call to env.step().

For child callback (of an EventCallback), this will be called when the event is triggered.

Return type:

bool

Returns:

If the callback returns False, training is aborted early.

update_child_locals(locals_)

Update the references to the local variables on sub callbacks.

Parameters:

locals – the local variables during rollout collection

Return type:

None

update_locals(locals_)

Update the references to the local variables.

Parameters:

locals – the local variables during rollout collection

Return type:

None

class stable_baselines3.common.callbacks.CallbackList(callbacks)

Class for chaining callbacks.

Parameters:

callbacks (List[BaseCallback]) – A list of callbacks that will be called sequentially.

update_child_locals(locals_)

Update the references to the local variables.

Parameters:

locals – the local variables during rollout collection

Return type:

None

class stable_baselines3.common.callbacks.CheckpointCallback(save_freq, save_path, name_prefix='rl_model', save_replay_buffer=False, save_vecnormalize=False, verbose=0)

Callback for saving a model every save_freq calls to env.step(). By default, it only saves model checkpoints; you need to pass save_replay_buffer=True and save_vecnormalize=True to also save the replay buffer and the normalization statistics.

Warning

When using multiple environments, each call to env.step() will effectively correspond to n_envs steps. To account for that, you can use save_freq = max(save_freq // n_envs, 1)

Parameters:
  • save_freq (int) – Save checkpoints every save_freq call of the callback.

  • save_path (str) – Path to the folder where the model will be saved.

  • name_prefix (str) – Common prefix to the saved models

  • save_replay_buffer (bool) – Save the model replay buffer

  • save_vecnormalize (bool) – Save the VecNormalize statistics

  • verbose (int) – Verbosity level: 0 for no output, 2 for indicating when saving model checkpoint

class stable_baselines3.common.callbacks.ConvertCallback(callback, verbose=0)

Convert a functional callback (old-style) to an object.

Parameters:
  • callback (Callable[[Dict[str, Any], Dict[str, Any]], bool]) –

  • verbose (int) – Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages

class stable_baselines3.common.callbacks.EvalCallback(eval_env, callback_on_new_best=None, callback_after_eval=None, n_eval_episodes=5, eval_freq=10000, log_path=None, best_model_save_path=None, deterministic=True, render=False, verbose=1, warn=True)

Callback for evaluating an agent.

Warning

When using multiple environments, each call to env.step() will effectively correspond to n_envs steps. To account for that, you can use eval_freq = max(eval_freq // n_envs, 1)

Parameters:
  • eval_env (Union[Env, VecEnv]) – The environment used for initialization

  • callback_on_new_best (Optional[BaseCallback]) – Callback to trigger when there is a new best model according to the mean_reward

  • callback_after_eval (Optional[BaseCallback]) – Callback to trigger after every evaluation

  • n_eval_episodes (int) – The number of episodes to test the agent

  • eval_freq (int) – Evaluate the agent every eval_freq call of the callback.

  • log_path (Optional[str]) – Path to a folder where the evaluations (evaluations.npz) will be saved. It will be updated at each evaluation.

  • best_model_save_path (Optional[str]) – Path to a folder where the best model according to performance on the eval env will be saved.

  • deterministic (bool) – Whether the evaluation should use stochastic or deterministic actions.

  • render (bool) – Whether or not to render the environment during evaluation

  • verbose (int) – Verbosity level: 0 for no output, 1 for indicating information about evaluation results

  • warn (bool) – Passed to evaluate_policy (warns if eval_env has not been wrapped with a Monitor wrapper)

update_child_locals(locals_)

Update the references to the local variables.

Parameters:

locals – the local variables during rollout collection

Return type:

None

class stable_baselines3.common.callbacks.EventCallback(callback=None, verbose=0)

Base class for triggering callback on event.

Parameters:
  • callback (Optional[BaseCallback]) – Callback that will be called when an event is triggered.

  • verbose (int) – Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages

init_callback(model)

Initialize the callback by saving references to the RL model and the training environment for convenience.

Return type:

None

update_child_locals(locals_)

Update the references to the local variables.

Parameters:

locals – the local variables during rollout collection

Return type:

None

class stable_baselines3.common.callbacks.EveryNTimesteps(n_steps, callback)

Trigger a callback every n_steps timesteps

Parameters:
  • n_steps (int) – Number of timesteps between two triggers.

  • callback (BaseCallback) – Callback that will be called when the event is triggered.

class stable_baselines3.common.callbacks.ProgressBarCallback

Display a progress bar when training SB3 agent using tqdm and rich packages.

class stable_baselines3.common.callbacks.StopTrainingOnMaxEpisodes(max_episodes, verbose=0)

Stop the training once a maximum number of episodes has been played.

For multiple environments, it presumes that the desired behavior is for the agent to train on each env for max_episodes and for max_episodes * n_envs episodes in total.

Parameters:
  • max_episodes (int) – Maximum number of episodes after which training stops.

  • verbose (int) – Verbosity level: 0 for no output, 1 for indicating information about when training ended by reaching max_episodes

class stable_baselines3.common.callbacks.StopTrainingOnNoModelImprovement(max_no_improvement_evals, min_evals=0, verbose=0)

Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

It is possible to define a minimum number of evaluations before starting to count evaluations without improvement.

It must be used with the EvalCallback.

Parameters:
  • max_no_improvement_evals (int) – Maximum number of consecutive evaluations without a new best model.

  • min_evals (int) – Number of evaluations before starting to count evaluations without improvement.

  • verbose (int) – Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model

class stable_baselines3.common.callbacks.StopTrainingOnRewardThreshold(reward_threshold, verbose=0)

Stop the training once a threshold in episodic reward has been reached (i.e. when the model is good enough).

It must be used with the EvalCallback.

Parameters:
  • reward_threshold (float) – Minimum expected reward per episode to stop training.

  • verbose (int) – Verbosity level: 0 for no output, 1 for indicating when training ended because episodic reward threshold reached