import glob
import os
import platform
import random
import re
import warnings
from collections import deque
from collections.abc import Iterable
import cloudpickle
import gymnasium as gym
import numpy as np
import torch as th
from gymnasium import spaces
import stable_baselines3 as sb3
# Check if tensorboard is available for pytorch
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None # type: ignore[misc, assignment]
from stable_baselines3.common.logger import Logger, configure
from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
[docs]
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Seed the different random generators.

    Python's ``random`` module, the global NumPy RNG and PyTorch are seeded
    (``th.manual_seed`` seeds the RNG for all devices, both CPU and CUDA).

    :param seed: Seed value given to every generator.
    :param using_cuda: If True, additionally make CuDNN deterministic and
        disable its benchmark autotuner (this may impact performance).
    """
    for seeder in (random.seed, np.random.seed, th.manual_seed):
        seeder(seed)
    if not using_cuda:
        return
    # Deterministic operations for CuDNN: reproducibility over speed.
    th.backends.cudnn.deterministic = True
    th.backends.cudnn.benchmark = False
# From stable baselines
[docs]
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """
    Compute the fraction of the variance of ``y_true`` that ``y_pred`` explains.

    Returns ``1 - Var[y_true - y_pred] / Var[y_true]``.

    interpretation:
        ev=0  =>  might as well have predicted a constant (e.g. zero)
        ev=1  =>  perfect prediction
        ev<0  =>  worse than just predicting zero

    :param y_pred: the prediction (1-D array)
    :param y_true: the expected value (1-D array)
    :return: explained variance of ``y_pred`` with respect to ``y_true``,
        or NaN when ``y_true`` has zero variance.
    """
    assert y_true.ndim == 1 and y_pred.ndim == 1
    target_variance = np.var(y_true)
    if target_variance == 0:
        # The ratio is undefined for a constant target.
        return np.nan
    residual_variance = np.var(y_true - y_pred)
    return float(1 - residual_variance / target_variance)
[docs]
def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
    """
    Overwrite the ``"lr"`` entry of every parameter group of ``optimizer``.

    Useful when following a learning-rate schedule (e.g. a linear one).

    :param optimizer: Pytorch optimizer, modified in place
    :param learning_rate: New learning rate value
    """
    for group in optimizer.param_groups:
        group.update(lr=learning_rate)
[docs]
class FloatSchedule:
    """
    Wrapper that ensures the output of a Schedule is cast to float.

    Accepts a constant (``float`` or ``int``, wrapped in a ConstantSchedule),
    another ``FloatSchedule`` (its inner schedule is reused, so wrappers never nest),
    or any other callable schedule (e.g. LinearSchedule, ConstantSchedule).

    :param value_schedule: Constant value or callable schedule
    """

    def __init__(self, value_schedule: Schedule | float):
        if isinstance(value_schedule, FloatSchedule):
            # Unwrap rather than building FloatSchedule(FloatSchedule(...)).
            inner: Schedule = value_schedule.value_schedule
        elif isinstance(value_schedule, (float, int)):
            inner = ConstantSchedule(float(value_schedule))
        else:
            assert callable(value_schedule), f"The learning rate schedule must be a float or a callable, not {value_schedule}"
            inner = value_schedule
        self.value_schedule: Schedule = inner

    def __call__(self, progress_remaining: float) -> float:
        # Cast to float to avoid unpickling errors to enable weights_only=True, see GH#1900.
        # Some types (e.g. numpy floats) behave oddly when returned from a Schedule.
        raw_value = self.value_schedule(progress_remaining)
        return float(raw_value)

    def __repr__(self) -> str:
        return f"FloatSchedule({self.value_schedule})"
[docs]
class LinearSchedule:
    """
    LinearSchedule interpolates linearly between ``start`` and ``end``
    between ``progress_remaining`` = 1 and ``progress_remaining`` = ``1 - end_fraction``,
    and returns ``end`` afterwards.
    This is used in DQN for linearly annealing the exploration fraction
    (epsilon for the epsilon-greedy strategy).

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value returned once the elapsed fraction ``1 - progress_remaining``
        exceeds ``end_fraction``
    :param end_fraction: fraction of the complete training process after which ``end``
        is reached, e.g. 0.1 means ``end`` is reached after 10% of training.
        ``0`` means ``end`` is returned right away.
    """

    def __init__(self, start: float, end: float, end_fraction: float) -> None:
        self.start = start
        self.end = end
        self.end_fraction = end_fraction

    def __call__(self, progress_remaining: float) -> float:
        # Fraction of the training process already elapsed.
        elapsed = 1 - progress_remaining
        # Without the explicit zero check, end_fraction == 0 divides by zero on the
        # very first call (progress_remaining == 1, so elapsed == 0 is not > 0).
        if self.end_fraction == 0 or elapsed > self.end_fraction:
            return self.end
        return self.start + elapsed * (self.end - self.start) / self.end_fraction

    def __repr__(self) -> str:
        return f"LinearSchedule(start={self.start}, end={self.end}, end_fraction={self.end_fraction})"
[docs]
class ConstantSchedule:
    """
    Schedule that ignores the training progress and always returns ``val``.

    Useful for fixed learning rates or clip ranges.

    :param val: constant value (returned as-is, without conversion)
    """

    def __init__(self, val: float):
        self.val = val

    def __call__(self, _: float) -> float:
        # The progress argument is accepted only to satisfy the Schedule signature.
        return self.val

    def __repr__(self) -> str:
        return f"ConstantSchedule(val={self.val})"
# ===== Deprecated schedule functions ====
# Only kept for backward compatibility when unpickling old models; use FloatSchedule
# and other classes like `LinearSchedule()` instead.
[docs]
def get_schedule_fn(value_schedule: Schedule | float) -> Schedule:
    """
    Transform (if needed) learning rate and clip range (for PPO)
    to callable.

    Deprecated: use ``FloatSchedule`` instead.

    :param value_schedule: Constant value or schedule function
    :return: Schedule function (can return constant value); its output
        is always cast to ``float``.
    """
    # stacklevel=2 attributes the warning to the caller instead of this line.
    warnings.warn("get_schedule_fn() is deprecated, please use FloatSchedule() instead", stacklevel=2)
    if isinstance(value_schedule, (float, int)):
        # Build the constant schedule inline: going through the deprecated
        # constant_fn() would emit a second, misleading deprecation warning.
        constant = float(value_schedule)

        def schedule(_: float) -> float:
            return constant

    else:
        assert callable(value_schedule)
        schedule = value_schedule
    # Cast to float to avoid unpickling errors to enable weights_only=True, see GH#1900
    # Some types are have odd behaviors when part of a Schedule, like numpy floats
    return lambda progress_remaining: float(schedule(progress_remaining))
[docs]
def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
    """
    Create a function that interpolates linearly between start and end
    between ``progress_remaining`` = 1 and ``progress_remaining`` = ``1 - end_fraction``,
    and returns ``end`` afterwards.
    This is used in DQN for linearly annealing the exploration fraction
    (epsilon for the epsilon-greedy strategy).

    Deprecated: use ``LinearSchedule`` instead.

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value returned once the elapsed fraction ``1 - progress_remaining``
        exceeds ``end_fraction``
    :param end_fraction: fraction of the complete training process after which
        ``end`` is reached, e.g. 0.1 means after 10% of training.
        ``0`` means ``end`` is returned right away.
    :return: Linear schedule function.
    """
    # stacklevel=2 attributes the warning to the caller instead of this line.
    warnings.warn("get_linear_fn() is deprecated, please use LinearSchedule() instead", stacklevel=2)

    def func(progress_remaining: float) -> float:
        elapsed = 1 - progress_remaining
        # The zero check avoids a ZeroDivisionError on the first call
        # (elapsed == 0) when end_fraction == 0.
        if end_fraction == 0 or elapsed > end_fraction:
            return end
        return start + elapsed * (end - start) / end_fraction

    return func
[docs]
def constant_fn(val: float) -> Schedule:
    """
    Create a function that returns a constant.
    It is useful for learning rate schedule (to avoid code duplication).

    Deprecated: use ``ConstantSchedule`` instead.

    :param val: constant value
    :return: Constant schedule function (ignores its argument and returns ``val``).
    """
    # stacklevel=2 attributes the warning to the caller instead of this line.
    warnings.warn("constant_fn() is deprecated, please use ConstantSchedule() instead", stacklevel=2)

    def func(_):
        return val

    return func
# ==== End of deprecated schedule functions ====
[docs]
def get_device(device: th.device | str = "auto") -> th.device:
"""
Retrieve PyTorch device.
It checks that the requested device is available first.
For now, it supports only cpu and cuda.
By default, it tries to use the gpu.
:param device: One for 'auto', 'cuda', 'cpu'
:return: Supported Pytorch device
"""
# Cuda by default
if device == "auto":
device = "cuda"
# Force conversion to th.device
device = th.device(device)
# Cuda not available
if device.type == th.device("cuda").type and not th.cuda.is_available():
return th.device("cpu")
return device
[docs]
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
    """
    Returns the latest run number for the given log name and log path,
    by finding the greatest number in the directories.

    :param log_path: Path to the log folder containing several runs.
    :param log_name: Name of the experiment. Each run is stored
        in a folder named ``log_name_1``, ``log_name_2``, ...
    :return: latest run number (0 if no matching run folder exists)
    """
    max_run_id = 0
    # glob.escape: special characters in log_name ("[", "*", "?") must match literally.
    pattern = os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")
    for path in glob.glob(pattern):
        file_name = path.split(os.sep)[-1]
        prefix, _, ext = file_name.rpartition("_")
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as "²" on which int() raises ValueError.
        if prefix == log_name and ext.isdecimal() and int(ext) > max_run_id:
            max_run_id = int(ext)
    return max_run_id
[docs]
def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space) -> None:
    """
    Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if
    spaces match after loading the model with given env.

    Checked parameters (in this order; the first mismatch raises):
    - observation_space
    - action_space

    :param env: Environment to check for valid spaces
    :param observation_space: Observation space to check against
    :param action_space: Action space to check against
    :raises ValueError: if a space differs from the corresponding one of ``env``
    """
    env_observation_space = env.observation_space
    if observation_space != env_observation_space:
        raise ValueError(f"Observation spaces do not match: {observation_space} != {env_observation_space}")
    env_action_space = env.action_space
    if action_space != env_action_space:
        raise ValueError(f"Action spaces do not match: {action_space} != {env_action_space}")
[docs]
def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
    """
    If the spaces are Box, check that they have the same shape.
    If the spaces are Dict, recursively check their subspaces (which must
    share the same keys). Other space types are not checked.

    :param space1: Space
    :param space2: Other space
    :raises AssertionError: if the spaces do not match
    """
    if isinstance(space1, spaces.Dict):
        assert isinstance(space2, spaces.Dict), f"spaces must be of the same type: {type(space1)} != {type(space2)}"
        keys1, keys2 = space1.spaces.keys(), space2.spaces.keys()
        assert keys1 == keys2, f"spaces must have the same keys: {list(keys1)} != {list(keys2)}"
        for key, subspace in space1.spaces.items():
            check_shape_equal(subspace, space2.spaces[key])
        return
    if isinstance(space1, spaces.Box):
        assert space1.shape == space2.shape, f"spaces must have the same shape: {space1.shape} != {space2.shape}"
[docs]
def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
    """
    For box observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: True if the observation has one extra leading (n_env) dimension,
        False if it has exactly the space's shape
    :raises ValueError: for any other shape
    """
    expected_shape = observation_space.shape
    if observation.shape == expected_shape:
        return False
    if observation.shape[1:] == expected_shape:
        return True
    raise ValueError(
        f"Error: Unexpected observation shape {observation.shape} for "
        + f"Box environment, please use {observation_space.shape} "
        + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
    )
[docs]
def is_vectorized_discrete_observation(observation: int | np.ndarray, observation_space: spaces.Discrete) -> bool:
    """
    For discrete observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space (not inspected)
    :return: False for a scalar observation, True for a 1-D batch
    :raises ValueError: for observations with two or more dimensions
    """
    # A plain int, or a numpy array holding a single number (shape ``()``), is one observation.
    is_single = isinstance(observation, int) or observation.shape == ()
    if is_single:
        return False
    if len(observation.shape) != 1:
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for "
            + "Discrete environment, please use () or (n_env,) for the observation shape."
        )
    return True
[docs]
def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: spaces.MultiDiscrete) -> bool:
    """
    For multidiscrete observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: False for shape ``(len(nvec),)``, True for ``(n_env, len(nvec))``
    :raises ValueError: for any other shape
    """
    n_components = len(observation_space.nvec)
    obs_shape = observation.shape
    if obs_shape == (n_components,):
        return False
    if len(obs_shape) == 2 and obs_shape[1] == n_components:
        return True
    raise ValueError(
        f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
        + f"environment, please use ({len(observation_space.nvec)},) or "
        + f"(n_env, {len(observation_space.nvec)}) for the observation shape."
    )
[docs]
def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: spaces.MultiBinary) -> bool:
    """
    For multibinary observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: False if the observation has exactly the space's shape,
        True if it has one extra leading (n_env) dimension
    :raises ValueError: for any other shape
    """
    if observation.shape == observation_space.shape:
        return False
    elif len(observation.shape) == len(observation_space.shape) + 1 and observation.shape[1:] == observation_space.shape:
        return True
    else:
        # Format the expected batched shape from ``shape`` (as the Box check does)
        # rather than from ``n``: ``n`` is not necessarily an int (gymnasium allows a
        # sequence for multi-dimensional MultiBinary), which printed "(n_env, [2, 3])".
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
            + f"environment, please use {observation_space.shape} or "
            + "(n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
        )
[docs]
def is_vectorized_dict_observation(observation: np.ndarray, observation_space: spaces.Dict) -> bool:
    """
    For dict observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    All entries must agree. Either every sub-observation has exactly its subspace's
    shape (not vectorized), or every sub-observation matches its subspace's shape
    once its leading dimension is dropped (vectorized).

    :param observation: the input observation to validate, indexed by the keys
        of ``observation_space.spaces``
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    :raises ValueError: when the entries fit neither layout (e.g. a mix of
        vectorized and non-vectorized observations)
    """
    subspaces = observation_space.spaces
    # First layout: every entry has exactly its subspace shape.
    if all(observation[key].shape == subspace.shape for key, subspace in subspaces.items()):
        return False
    # Second layout: every entry carries an extra leading (batch) dimension.
    for key, subspace in subspaces.items():
        if observation[key].shape[1:] != subspace.shape:
            break
    else:
        return True
    # `key` is the first entry that fits neither layout; ask the per-space
    # checker for a more specific explanation, if it has one.
    error_msg = ""
    try:
        is_vectorized_observation(observation[key], observation_space.spaces[key])
    except ValueError as e:
        error_msg = f"{e}"
    raise ValueError(
        f"There seems to be a mix of vectorized and non-vectorized observations. "
        f"Unexpected observation shape {observation[key].shape} for key {key} "
        f"of type {observation_space.spaces[key]}. {error_msg}"
    )
[docs]
def is_vectorized_observation(observation: int | np.ndarray, observation_space: spaces.Space) -> bool:
    """
    For every observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    Dispatches to the type-specific checker of the first matching space type
    (Box, Discrete, MultiDiscrete, MultiBinary, Dict, checked in that order).

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    :raises ValueError: if the space type is not supported, or if the
        type-specific checker rejects the observation shape
    """
    checkers = (
        (spaces.Box, is_vectorized_box_observation),
        (spaces.Discrete, is_vectorized_discrete_observation),
        (spaces.MultiDiscrete, is_vectorized_multidiscrete_observation),
        (spaces.MultiBinary, is_vectorized_multibinary_observation),
        (spaces.Dict, is_vectorized_dict_observation),
    )
    for space_type, checker in checkers:
        if isinstance(observation_space, space_type):
            return checker(observation, observation_space)  # type: ignore[operator]
    raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.")
[docs]
def safe_mean(arr: np.ndarray | list | deque) -> float:
"""
Compute the mean of an array if there is at least one element.
For empty array, return NaN. It is used for logging only.
:param arr: Numpy array or list of values
:return:
"""
return np.nan if len(arr) == 0 else float(np.mean(arr)) # type: ignore[arg-type]
[docs]
def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> list[th.Tensor]:
    """
    Extract parameters from the state dict of ``model``
    if the name contains one of the strings in ``included_names``.

    :param model: the model where the parameters come from.
    :param included_names: substrings of names to include. Any iterable is
        accepted, including one-shot iterators such as generators.
    :return: List of parameters values (Pytorch tensors)
        that matches the queried names, in state-dict order.
    """
    # Materialize once: a one-shot iterator would otherwise be exhausted while
    # checking the first state-dict entry, silently matching nothing afterwards.
    patterns = tuple(included_names)
    return [param for name, param in model.state_dict().items() if any(key in name for key in patterns)]
[docs]
def zip_strict(*iterables: Iterable) -> Iterable:
r"""
``zip()`` function but enforces that iterables are of equal length.
Raises ``ValueError`` if iterables not of equal length.
It used to be a polyfill for Python 3.19 taken from Stackoverflow #32954486.
Since Python 3.10 is the minimum version, it is kept only for legacy
and is just returning zip(..., strict=True).
:param \*iterables: iterables to ``zip()``
"""
return zip(*iterables, strict=True)
[docs]
def polyak_update(
params: Iterable[th.Tensor],
target_params: Iterable[th.Tensor],
tau: float,
) -> None:
"""
Perform a Polyak average update on ``target_params`` using ``params``:
target parameters are slowly updated towards the main parameters.
``tau``, the soft update coefficient controls the interpolation:
``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``.
The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors,
or a computation graph, reducing memory cost and improving performance. We scale the target params
by ``1-tau`` (in-place), add the new weights, scaled by ``tau`` and store the result of the sum in the target
params (in place).
See https://github.com/DLR-RM/stable-baselines3/issues/93
:param params: parameters to use to update the target params
:param target_params: parameters to update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
"""
with th.no_grad():
for param, target_param in zip(params, target_params, strict=True):
target_param.data.mul_(1 - tau)
th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
[docs]
def obs_as_tensor(obs: np.ndarray | dict[str, np.ndarray], device: th.device) -> th.Tensor | TensorDict:
    """
    Moves the observation to the given device.

    :param obs: a numpy array, or a dict mapping keys to numpy arrays
    :param device: PyTorch device
    :return: PyTorch tensor (or dict of tensors, same keys) of the observation
        on the desired device.
    :raises TypeError: for any other observation type
    """
    if isinstance(obs, dict):
        return {name: th.as_tensor(value, device=device) for name, value in obs.items()}
    if isinstance(obs, np.ndarray):
        return th.as_tensor(obs, device=device)
    raise TypeError(f"Unrecognized type of observation {type(obs)}")
[docs]
def should_collect_more_steps(
train_freq: TrainFreq,
num_collected_steps: int,
num_collected_episodes: int,
) -> bool:
"""
Helper used in ``collect_rollouts()`` of off-policy algorithms
to determine the termination condition.
:param train_freq: How much experience should be collected before updating the policy.
:param num_collected_steps: The number of already collected steps.
:param num_collected_episodes: The number of already collected episodes.
:return: Whether to continue or not collecting experience
by doing rollouts of the current policy.
"""
if train_freq.unit == TrainFrequencyUnit.STEP:
return num_collected_steps < train_freq.frequency
elif train_freq.unit == TrainFrequencyUnit.EPISODE:
return num_collected_episodes < train_freq.frequency
else:
raise ValueError(
"The unit of the `train_freq` must be either TrainFrequencyUnit.STEP "
f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!"
)
[docs]
def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
    """
    Retrieve system and python env info for the current system.

    :param print_info: Whether to print or not those infos
    :return: Dictionary summing up the version for each relevant package
        and a formatted string with one ``- name: version`` line per entry.
    """
    # Insert a space between "#" and a following digit so that e.g. "#42" is not
    # wrongly rendered as a link to another GitHub issue: "#42" -> "# 42".
    os_description = re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}")
    env_info = {
        "OS": os_description,
        "Python": platform.python_version(),
        "Stable-Baselines3": sb3.__version__,
        "PyTorch": th.__version__,
        "GPU Enabled": str(th.cuda.is_available()),
        "Numpy": np.__version__,
        "Cloudpickle": cloudpickle.__version__,
        "Gymnasium": gym.__version__,
    }
    try:
        # The legacy OpenAI Gym package is optional: only report it when installed.
        import gym as openai_gym

        env_info["OpenAI Gym"] = openai_gym.__version__
    except ImportError:
        pass
    env_info_str = "".join(f"- {name}: {version}\n" for name, version in env_info.items())
    if print_info:
        print(env_info_str)
    return env_info, env_info_str