import os
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
import gym
import numpy as np
from stable_baselines3.common import base_class # pytype: disable=pyi-error
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
class BaseCallback(ABC):
    """
    Base class for callbacks.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages.
    """
def __init__(self, verbose: int = 0):
super(BaseCallback, self).__init__()
# The RL model
self.model = None # type: Optional[base_class.BaseAlgorithm]
# An alias for self.model.get_env(), the environment used for training
self.training_env = None # type: Union[gym.Env, VecEnv, None]
        # Number of times the callback was called
        self.n_calls = 0  # type: int
        # Total number of timesteps so far: n_envs * (number of times env.step() was called)
        self.num_timesteps = 0  # type: int
self.verbose = verbose
self.locals: Dict[str, Any] = {}
self.globals: Dict[str, Any] = {}
self.logger = None
        # Sometimes, for an event callback, it is useful
        # to have access to the parent object
self.parent = None # type: Optional[BaseCallback]
# Type hint as string to avoid circular import
    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
"""
Initialize the callback by saving references to the
RL model and the training environment for convenience.
"""
self.model = model
self.training_env = model.get_env()
self.logger = model.logger
self._init_callback()
def _init_callback(self) -> None:
pass
def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        # These are references that will be updated automatically
self.locals = locals_
self.globals = globals_
self._on_training_start()
def _on_training_start(self) -> None:
pass
def on_rollout_start(self) -> None:
self._on_rollout_start()
def _on_rollout_start(self) -> None:
pass
@abstractmethod
def _on_step(self) -> bool:
"""
:return: If the callback returns False, training is aborted early.
"""
return True
    def on_step(self) -> bool:
        """
        This method will be called by the model after each call to ``env.step()``.

        For child callbacks (of an ``EventCallback``), this will be called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
self.n_calls += 1
# timesteps start at zero
self.num_timesteps = self.model.num_timesteps
return self._on_step()
def on_training_end(self) -> None:
self._on_training_end()
def _on_training_end(self) -> None:
pass
def on_rollout_end(self) -> None:
self._on_rollout_end()
def _on_rollout_end(self) -> None:
pass
    def update_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables.
:param locals_: the local variables during rollout collection
"""
self.locals.update(locals_)
self.update_child_locals(locals_)
    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables on sub callbacks.
:param locals_: the local variables during rollout collection
"""
pass
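# Illustrative usage sketch (not part of the library): a minimal custom callback.
# Subclasses override the private ``_on_step`` hook (not ``on_step``) so that the
# bookkeeping above still runs. The class name and step budget are hypothetical.
def _example_custom_callback() -> None:
    from stable_baselines3 import PPO

    class StepBudgetCallback(BaseCallback):
        """Stop training once a fixed number of timesteps has been collected."""

        def __init__(self, max_steps: int, verbose: int = 0):
            super().__init__(verbose)
            self.max_steps = max_steps

        def _on_step(self) -> bool:
            # Returning False aborts training early
            return self.num_timesteps < self.max_steps

    model = PPO("MlpPolicy", "CartPole-v1")
    model.learn(total_timesteps=100_000, callback=StepBudgetCallback(max_steps=50_000))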
class EventCallback(BaseCallback):
    """
    Base class for triggering a callback on an event.

    :param callback: Callback that will be called
        when an event is triggered.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages.
    """
def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
super(EventCallback, self).__init__(verbose=verbose)
self.callback = callback
# Give access to the parent
if callback is not None:
self.callback.parent = self
    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
super(EventCallback, self).init_callback(model)
if self.callback is not None:
self.callback.init_callback(self.model)
def _on_training_start(self) -> None:
if self.callback is not None:
self.callback.on_training_start(self.locals, self.globals)
def _on_event(self) -> bool:
if self.callback is not None:
return self.callback.on_step()
return True
def _on_step(self) -> bool:
return True
    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables.
:param locals_: the local variables during rollout collection
"""
if self.callback is not None:
self.callback.update_locals(locals_)
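# Illustrative sketch: an ``EventCallback`` subclass decides in ``_on_step`` when to
# fire its child callback via ``_on_event()``. The episode-end trigger below is a
# hypothetical example, not a library feature.
def _example_event_callback() -> None:
    class OnEpisodeEnd(EventCallback):
        """Trigger the child callback whenever any env reports a ``done`` signal."""

        def _on_step(self) -> bool:
            # ``dones`` is assumed to be present in the rollout locals
            if np.any(self.locals.get("dones", False)):
                return self._on_event()
            return True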
class CallbackList(BaseCallback):
"""
Class for chaining callbacks.
:param callbacks: A list of callbacks that will be called
sequentially.
"""
def __init__(self, callbacks: List[BaseCallback]):
super(CallbackList, self).__init__()
assert isinstance(callbacks, list)
self.callbacks = callbacks
def _init_callback(self) -> None:
for callback in self.callbacks:
callback.init_callback(self.model)
def _on_training_start(self) -> None:
for callback in self.callbacks:
callback.on_training_start(self.locals, self.globals)
def _on_rollout_start(self) -> None:
for callback in self.callbacks:
callback.on_rollout_start()
def _on_step(self) -> bool:
continue_training = True
for callback in self.callbacks:
# Return False (stop training) if at least one callback returns False
continue_training = callback.on_step() and continue_training
return continue_training
def _on_rollout_end(self) -> None:
for callback in self.callbacks:
callback.on_rollout_end()
def _on_training_end(self) -> None:
for callback in self.callbacks:
callback.on_training_end()
    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables.
:param locals_: the local variables during rollout collection
"""
for callback in self.callbacks:
callback.update_locals(locals_)
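# Usage sketch: chain several callbacks explicitly. Passing a plain list to
# ``model.learn`` should get wrapped into a ``CallbackList`` by the algorithm as
# well; env ids and paths below are placeholders.
def _example_callback_list() -> None:
    from stable_baselines3 import PPO

    callback = CallbackList(
        [
            CheckpointCallback(save_freq=1000, save_path="./checkpoints/"),
            EvalCallback(gym.make("CartPole-v1"), eval_freq=500),
        ]
    )
    model = PPO("MlpPolicy", "CartPole-v1")
    model.learn(total_timesteps=10_000, callback=callback)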
class CheckpointCallback(BaseCallback):
    """
    Callback for saving a model every ``save_freq`` calls
    to ``env.step()``.

    .. warning::

      When using multiple environments, each call to ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)``

    :param save_freq: Save a checkpoint every ``save_freq`` calls of the callback.
    :param save_path: Path to the folder where the model will be saved.
    :param name_prefix: Common prefix to the saved models
    :param verbose: Verbosity level (a value above 1 prints a message for each checkpoint).
    """
def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
super(CheckpointCallback, self).__init__(verbose)
self.save_freq = save_freq
self.save_path = save_path
self.name_prefix = name_prefix
def _init_callback(self) -> None:
# Create folder if needed
if self.save_path is not None:
os.makedirs(self.save_path, exist_ok=True)
def _on_step(self) -> bool:
if self.n_calls % self.save_freq == 0:
path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps")
self.model.save(path)
if self.verbose > 1:
print(f"Saving model checkpoint to {path}")
return True
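# Usage sketch: with ``n_envs`` parallel environments, divide ``save_freq`` as the
# warning above suggests so that checkpoints land roughly every 10_000 timesteps.
# Paths, env id, and frequencies are placeholders.
def _example_checkpoint_callback() -> None:
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env

    n_envs = 4
    checkpoint_callback = CheckpointCallback(
        save_freq=max(10_000 // n_envs, 1),
        save_path="./checkpoints/",
        name_prefix="rl_model",
    )
    model = PPO("MlpPolicy", make_vec_env("CartPole-v1", n_envs=n_envs))
    model.learn(total_timesteps=50_000, callback=checkpoint_callback)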
class ConvertCallback(BaseCallback):
    """
    Convert a functional callback (old-style) to a callback object.

    :param callback: Functional callback that takes the ``locals()`` and ``globals()``
        dictionaries and returns a boolean (``False`` stops training).
    :param verbose: Verbosity level: 0 for no output, 1 for info messages.
    """
def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0):
super(ConvertCallback, self).__init__(verbose)
self.callback = callback
def _on_step(self) -> bool:
if self.callback is not None:
return self.callback(self.locals, self.globals)
return True
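# Usage sketch: an old-style callback is a plain function taking the ``locals()``
# and ``globals()`` dictionaries and returning a bool. ``model.learn`` is expected
# to wrap bare callables in ``ConvertCallback`` itself, so the explicit wrap below
# is optional; the budget check assumes ``self`` (the algorithm) is in the locals.
def _example_convert_callback() -> None:
    def old_style_callback(locals_: Dict[str, Any], globals_: Dict[str, Any]) -> bool:
        # Keep training while fewer than 10_000 timesteps have been collected
        return locals_["self"].num_timesteps < 10_000

    callback = ConvertCallback(old_style_callback)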
class EvalCallback(EventCallback):
    """
    Callback for evaluating an agent.

    .. warning::

      When using multiple environments, each call to ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)``

    :param eval_env: The environment used for evaluation
    :param callback_on_new_best: Callback to trigger
        when there is a new best model according to the ``mean_reward``
    :param n_eval_episodes: The number of episodes to test the agent
    :param eval_freq: Evaluate the agent every ``eval_freq`` calls of the callback.
    :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
        will be saved. It will be updated at each evaluation.
    :param best_model_save_path: Path to a folder where the best model
        according to performance on the eval env will be saved.
    :param deterministic: Whether the evaluation should
        use deterministic or stochastic actions.
    :param render: Whether to render the environment during evaluation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages.
    :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
        wrapped with a Monitor wrapper)
    """
def __init__(
self,
eval_env: Union[gym.Env, VecEnv],
callback_on_new_best: Optional[BaseCallback] = None,
n_eval_episodes: int = 5,
eval_freq: int = 10000,
log_path: Optional[str] = None,
best_model_save_path: Optional[str] = None,
deterministic: bool = True,
render: bool = False,
verbose: int = 1,
warn: bool = True,
):
super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose)
self.n_eval_episodes = n_eval_episodes
self.eval_freq = eval_freq
self.best_mean_reward = -np.inf
self.last_mean_reward = -np.inf
self.deterministic = deterministic
self.render = render
self.warn = warn
# Convert to VecEnv for consistency
if not isinstance(eval_env, VecEnv):
eval_env = DummyVecEnv([lambda: eval_env])
self.eval_env = eval_env
self.best_model_save_path = best_model_save_path
# Logs will be written in ``evaluations.npz``
if log_path is not None:
log_path = os.path.join(log_path, "evaluations")
self.log_path = log_path
self.evaluations_results = []
self.evaluations_timesteps = []
self.evaluations_length = []
# For computing success rate
self._is_success_buffer = []
self.evaluations_successes = []
def _init_callback(self) -> None:
        # Does not work in some corner cases, where the wrapper is not the same
        if not isinstance(self.training_env, type(self.eval_env)):
            warnings.warn(f"Training and eval env are not of the same type: {self.training_env} != {self.eval_env}")
# Create folders if needed
if self.best_model_save_path is not None:
os.makedirs(self.best_model_save_path, exist_ok=True)
if self.log_path is not None:
os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
"""
Callback passed to the ``evaluate_policy`` function
in order to log the success rate (when applicable),
for instance when using HER.
:param locals_:
:param globals_:
"""
info = locals_["info"]
if locals_["done"]:
maybe_is_success = info.get("is_success")
if maybe_is_success is not None:
self._is_success_buffer.append(maybe_is_success)
def _on_step(self) -> bool:
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
# Sync training and eval env if there is VecNormalize
if self.model.get_vec_normalize_env() is not None:
try:
sync_envs_normalization(self.training_env, self.eval_env)
except AttributeError:
raise AssertionError(
"Training and eval env are not wrapped the same way, "
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
"and warning above."
)
# Reset success rate buffer
self._is_success_buffer = []
episode_rewards, episode_lengths = evaluate_policy(
self.model,
self.eval_env,
n_eval_episodes=self.n_eval_episodes,
render=self.render,
deterministic=self.deterministic,
return_episode_rewards=True,
warn=self.warn,
callback=self._log_success_callback,
)
if self.log_path is not None:
self.evaluations_timesteps.append(self.num_timesteps)
self.evaluations_results.append(episode_rewards)
self.evaluations_length.append(episode_lengths)
kwargs = {}
# Save success log if present
if len(self._is_success_buffer) > 0:
self.evaluations_successes.append(self._is_success_buffer)
kwargs = dict(successes=self.evaluations_successes)
np.savez(
self.log_path,
timesteps=self.evaluations_timesteps,
results=self.evaluations_results,
ep_lengths=self.evaluations_length,
**kwargs,
)
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
self.last_mean_reward = mean_reward
if self.verbose > 0:
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
# Add to current Logger
self.logger.record("eval/mean_reward", float(mean_reward))
self.logger.record("eval/mean_ep_length", mean_ep_length)
if len(self._is_success_buffer) > 0:
success_rate = np.mean(self._is_success_buffer)
if self.verbose > 0:
print(f"Success rate: {100 * success_rate:.2f}%")
self.logger.record("eval/success_rate", success_rate)
# Dump log so the evaluation results are printed with the correct timestep
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(self.num_timesteps)
if mean_reward > self.best_mean_reward:
if self.verbose > 0:
print("New best mean reward!")
if self.best_model_save_path is not None:
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
self.best_mean_reward = mean_reward
# Trigger callback if needed
if self.callback is not None:
return self._on_event()
return True
    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables.
:param locals_: the local variables during rollout collection
"""
if self.callback:
self.callback.update_locals(locals_)
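# Usage sketch: evaluate on a separate env every 500 calls, keeping the best model.
# Paths, env id, and frequencies below are placeholders.
def _example_eval_callback() -> None:
    from stable_baselines3 import PPO

    eval_callback = EvalCallback(
        gym.make("CartPole-v1"),
        best_model_save_path="./logs/best_model/",
        log_path="./logs/",
        eval_freq=500,
        n_eval_episodes=5,
        deterministic=True,
    )
    model = PPO("MlpPolicy", "CartPole-v1")
    model.learn(total_timesteps=20_000, callback=eval_callback)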
class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop the training once a threshold in episodic reward
    has been reached (i.e. when the model is good enough).

    It must be used with the ``EvalCallback``.

    :param reward_threshold: Minimum expected reward per episode
        to stop training.
    :param verbose: Verbosity level: 0 for no output, 1 to print a message when training stops.
    """
def __init__(self, reward_threshold: float, verbose: int = 0):
super(StopTrainingOnRewardThreshold, self).__init__(verbose=verbose)
self.reward_threshold = reward_threshold
def _on_step(self) -> bool:
        assert self.parent is not None, "``StopTrainingOnRewardThreshold`` callback must be used with an ``EvalCallback``"
        # Convert np.bool_ to bool, otherwise the ``callback() is False`` check won't work
        continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
        if self.verbose > 0 and not continue_training:
            print(
                f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
                f"is above the threshold {self.reward_threshold}"
            )
return continue_training
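# Usage sketch: this callback must be passed as the ``callback_on_new_best`` child
# of an ``EvalCallback`` (it reads ``self.parent.best_mean_reward``). The threshold
# of 475 is a placeholder suited to CartPole-v1.
def _example_stop_on_reward_threshold() -> None:
    from stable_baselines3 import PPO

    stop_callback = StopTrainingOnRewardThreshold(reward_threshold=475, verbose=1)
    eval_callback = EvalCallback(gym.make("CartPole-v1"), callback_on_new_best=stop_callback, eval_freq=1000)
    model = PPO("MlpPolicy", "CartPole-v1")
    # Set a large budget; training stops early once the threshold is reached
    model.learn(total_timesteps=int(1e6), callback=eval_callback)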
class EveryNTimesteps(EventCallback):
    """
    Trigger a callback every ``n_steps`` timesteps.

    :param n_steps: Number of timesteps between two triggers.
    :param callback: Callback that will be called
        when the event is triggered.
    """
def __init__(self, n_steps: int, callback: BaseCallback):
super(EveryNTimesteps, self).__init__(callback)
self.n_steps = n_steps
self.last_time_trigger = 0
def _on_step(self) -> bool:
if (self.num_timesteps - self.last_time_trigger) >= self.n_steps:
self.last_time_trigger = self.num_timesteps
return self._on_event()
return True
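# Usage sketch: with ``save_freq=1`` the checkpoint fires on every event, so the
# pair below saves roughly every 5000 timesteps regardless of the number of
# parallel envs. Paths and env id are placeholders.
def _example_every_n_timesteps() -> None:
    from stable_baselines3 import PPO

    checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./checkpoints/")
    event_callback = EveryNTimesteps(n_steps=5000, callback=checkpoint_on_event)
    model = PPO("MlpPolicy", "Pendulum-v1")
    model.learn(total_timesteps=20_000, callback=event_callback)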
class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop the training once a maximum number of episodes are played.

    For multiple environments, presumes that the desired behavior is that the agent
    trains on each env for ``max_episodes`` and in total for ``max_episodes * n_envs`` episodes.

    :param max_episodes: Maximum number of episodes to play per environment before stopping training.
    :param verbose: Whether to print information when training stops because ``max_episodes`` was reached.
    """
def __init__(self, max_episodes: int, verbose: int = 0):
super(StopTrainingOnMaxEpisodes, self).__init__(verbose=verbose)
self.max_episodes = max_episodes
self._total_max_episodes = max_episodes
self.n_episodes = 0
def _init_callback(self) -> None:
        # At start, set the total max according to the number of environments
self._total_max_episodes = self.max_episodes * self.training_env.num_envs
def _on_step(self) -> bool:
# Check that the `dones` local variable is defined
assert "dones" in self.locals, "`dones` variable is not defined, please check your code next to `callback.on_step()`"
self.n_episodes += np.sum(self.locals["dones"]).item()
continue_training = self.n_episodes < self._total_max_episodes
if self.verbose > 0 and not continue_training:
mean_episodes_per_env = self.n_episodes / self.training_env.num_envs
mean_ep_str = (
f"with an average of {mean_episodes_per_env:.2f} episodes per env" if self.training_env.num_envs > 1 else ""
)
print(
f"Stopping training with a total of {self.num_timesteps} steps because the "
f"{self.locals.get('tb_log_name')} model reached max_episodes={self.max_episodes}, "
f"by playing for {self.n_episodes} episodes "
f"{mean_ep_str}"
)
return continue_training
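# Usage sketch: with one env this stops after 100 episodes; with ``n_envs`` envs the
# effective total is ``100 * n_envs`` episodes, as the docstring above explains.
# Algorithm, env id, and budget are placeholders.
def _example_stop_on_max_episodes() -> None:
    from stable_baselines3 import A2C

    callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)
    model = A2C("MlpPolicy", "Pendulum-v1")
    # The timestep budget is only an upper bound; episodes are counted via ``dones``
    model.learn(total_timesteps=int(1e6), callback=callback_max_episodes)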