Source code for stable_baselines3.common.utils

import glob
import os
import platform
import random
import re
from collections import deque
from itertools import zip_longest
from typing import Dict, Iterable, List, Optional, Tuple, Union

import cloudpickle
import gymnasium as gym
import numpy as np
import torch as th
from gymnasium import spaces

import stable_baselines3 as sb3

# Check if tensorboard is available for pytorch
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None  # type: ignore[misc, assignment]

from stable_baselines3.common.logger import Logger, configure
from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit


def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Seed the different random generators.

    :param seed: Seed value
    :param using_cuda: Whether CUDA is being used
    """
    # Seed python RNG
    random.seed(seed)
    # Seed numpy RNG
    np.random.seed(seed)
    # seed the RNG for all devices (both CPU and CUDA)
    th.manual_seed(seed)

    if using_cuda:
        # Deterministic operations for CuDNN, it may impact performance
        th.backends.cudnn.deterministic = True
        th.backends.cudnn.benchmark = False

# From stable baselines
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
    """
    Computes fraction of variance that ypred explains about y.
    Returns 1 - Var[y-ypred] / Var[y]

    interpretation:
        ev=0  =>  might as well have predicted zero
        ev=1  =>  perfect prediction
        ev<0  =>  worse than just predicting zero

    :param y_pred: the prediction
    :param y_true: the expected value
    :return: explained variance of ypred and y
    """
    assert y_true.ndim == 1 and y_pred.ndim == 1
    var_y = np.var(y_true)
    return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

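# Example (illustrative, not part of the library): a perfect prediction gives
# an explained variance of 1, while predicting the mean of the targets gives 0.
#
#     >>> y_true = np.array([1.0, 2.0, 3.0])
#     >>> float(explained_variance(y_true, y_true))
#     1.0
#     >>> float(explained_variance(np.full(3, 2.0), y_true))
#     0.0
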
def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
    """
    Update the learning rate for a given optimizer.
    Useful when doing linear schedule.

    :param optimizer: Pytorch optimizer
    :param learning_rate: New learning rate value
    """
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate

def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
    """
    Transform (if needed) learning rate and clip range (for PPO)
    to callable.

    :param value_schedule: Constant value or schedule function
    :return: Schedule function (can return constant value)
    """
    # If the passed schedule is a float
    # create a constant function
    if isinstance(value_schedule, (float, int)):
        # Cast to float to avoid errors
        value_schedule = constant_fn(float(value_schedule))
    else:
        assert callable(value_schedule)
    # Cast to float to avoid unpickling errors to enable weights_only=True, see GH#1900
    # Some types have odd behaviors when part of a Schedule, like numpy floats
    return lambda progress_remaining: float(value_schedule(progress_remaining))

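# Example (illustrative): a constant value and a callable are both turned into
# a ``Schedule``, a callable taking ``progress_remaining`` (which goes from 1
# to 0 during training).
#
#     >>> lr_schedule = get_schedule_fn(3e-4)
#     >>> lr_schedule(0.5)
#     0.0003
#     >>> clip_schedule = get_schedule_fn(lambda progress_remaining: 0.2 * progress_remaining)
#     >>> clip_schedule(0.5)
#     0.1
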
def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
    """
    Create a function that interpolates linearly between start and end
    between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``.
    This is used in DQN for linearly annealing the exploration fraction
    (epsilon for the epsilon-greedy strategy).

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value to end with if ``progress_remaining`` = 0
    :param end_fraction: fraction of ``progress_remaining`` where end is reached,
        e.g. 0.1 means that end is reached after 10% of the complete training process.
    :return: Linear schedule function.
    """

    def func(progress_remaining: float) -> float:
        if (1 - progress_remaining) > end_fraction:
            return end
        else:
            return start + (1 - progress_remaining) * (end - start) / end_fraction

    return func

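# Example (illustrative): anneal epsilon from 1.0 to 0.05 over the first 10%
# of training (``progress_remaining`` goes from 1 to 0).
#
#     >>> eps = get_linear_fn(start=1.0, end=0.05, end_fraction=0.1)
#     >>> eps(1.0)  # beginning of training
#     1.0
#     >>> round(eps(0.95), 3)  # halfway through the annealing phase
#     0.525
#     >>> eps(0.5)  # annealing is over, the value stays at ``end``
#     0.05
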
def constant_fn(val: float) -> Schedule:
    """
    Create a function that returns a constant.
    It is useful for learning rate schedules (to avoid code duplication).

    :param val: constant value
    :return: Constant schedule function.
    """

    def func(_):
        return val

    return func

def get_device(device: Union[th.device, str] = "auto") -> th.device:
    """
    Retrieve PyTorch device.
    It checks that the requested device is available first.
    For now, it supports only CPU and CUDA.
    By default, it tries to use the GPU.

    :param device: One of 'auto', 'cuda', 'cpu'
    :return: Supported PyTorch device
    """
    # Cuda by default
    if device == "auto":
        device = "cuda"
    # Force conversion to th.device
    device = th.device(device)

    # Cuda not available
    if device.type == th.device("cuda").type and not th.cuda.is_available():
        return th.device("cpu")

    return device

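# Example (illustrative):
#
#     >>> get_device("cpu")
#     device(type='cpu')
#     >>> get_device("auto")  # cuda if available, falls back to cpu otherwise
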
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
    """
    Returns the latest run number for the given log name and log path,
    by finding the greatest number in the directories.

    :param log_path: Path to the log folder containing several runs.
    :param log_name: Name of the experiment. Each run is stored
        in a folder named ``log_name_1``, ``log_name_2``, ...
    :return: latest run number
    """
    max_run_id = 0
    for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")):
        file_name = path.split(os.sep)[-1]
        ext = file_name.split("_")[-1]
        if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
            max_run_id = int(ext)
    return max_run_id

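# Example (illustrative): if ``./logs`` contains the folders ``PPO_1`` and
# ``PPO_2``, then ``get_latest_run_id("./logs", "PPO")`` returns 2, so the
# next run will be logged to ``PPO_3``.
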
def configure_logger(
    verbose: int = 0,
    tensorboard_log: Optional[str] = None,
    tb_log_name: str = "",
    reset_num_timesteps: bool = True,
) -> Logger:
    """
    Configure the logger's outputs.

    :param verbose: Verbosity level: 0 for no output, 1 for the standard output to be part of the logger outputs
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param tb_log_name: tensorboard log name for this run
    :param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not.
        It allows continuing a previous learning curve (``reset_num_timesteps=False``)
        or starting from t=0 (``reset_num_timesteps=True``, the default).
    :return: The logger object
    """
    save_path, format_strings = None, ["stdout"]

    if tensorboard_log is not None and SummaryWriter is None:
        raise ImportError("Trying to log data to tensorboard but tensorboard is not installed.")

    if tensorboard_log is not None and SummaryWriter is not None:
        latest_run_id = get_latest_run_id(tensorboard_log, tb_log_name)
        if not reset_num_timesteps:
            # Continue training in the same directory
            latest_run_id -= 1
        save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{latest_run_id + 1}")
        if verbose >= 1:
            format_strings = ["stdout", "tensorboard"]
        else:
            format_strings = ["tensorboard"]
    elif verbose == 0:
        format_strings = [""]
    return configure(save_path, format_strings=format_strings)

def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space) -> None:
    """
    Checks that the environment has the same spaces as the provided ones.
    Used by BaseAlgorithm to check if spaces match after loading the model with a given env.
    Checked parameters:
    - observation_space
    - action_space

    :param env: Environment to check for valid spaces
    :param observation_space: Observation space to check against
    :param action_space: Action space to check against
    """
    if observation_space != env.observation_space:
        raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")
    if action_space != env.action_space:
        raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")

def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
    """
    If the spaces are Box, check that they have the same shape.
    If the spaces are Dict, it recursively checks the subspaces.

    :param space1: Space
    :param space2: Other space
    """
    if isinstance(space1, spaces.Dict):
        assert isinstance(space2, spaces.Dict), "spaces must be of the same type"
        assert space1.spaces.keys() == space2.spaces.keys(), "spaces must have the same keys"
        for key in space1.spaces.keys():
            check_shape_equal(space1.spaces[key], space2.spaces[key])
    elif isinstance(space1, spaces.Box):
        assert space1.shape == space2.shape, "spaces must have the same shape"

def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
    """
    For box observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    """
    if observation.shape == observation_space.shape:
        return False
    elif observation.shape[1:] == observation_space.shape:
        return True
    else:
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for "
            + f"Box environment, please use {observation_space.shape} "
            + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
        )

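# Example (illustrative): for a Box space of shape (3,), a (3,) observation is
# a single observation, while a (n_env, 3) observation is a vectorized batch.
#
#     >>> space = spaces.Box(low=-1.0, high=1.0, shape=(3,))
#     >>> is_vectorized_box_observation(np.zeros((3,)), space)
#     False
#     >>> is_vectorized_box_observation(np.zeros((5, 3)), space)
#     True
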
def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: spaces.Discrete) -> bool:
    """
    For discrete observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    """
    if isinstance(observation, int) or observation.shape == ():  # A numpy array of a number, has shape empty tuple '()'
        return False
    elif len(observation.shape) == 1:
        return True
    else:
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for "
            + "Discrete environment, please use () or (n_env,) for the observation shape."
        )

def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: spaces.MultiDiscrete) -> bool:
    """
    For multidiscrete observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    """
    if observation.shape == (len(observation_space.nvec),):
        return False
    elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
        return True
    else:
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
            + f"environment, please use ({len(observation_space.nvec)},) or "
            + f"(n_env, {len(observation_space.nvec)}) for the observation shape."
        )

def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: spaces.MultiBinary) -> bool:
    """
    For multibinary observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    """
    if observation.shape == observation_space.shape:
        return False
    elif len(observation.shape) == len(observation_space.shape) + 1 and observation.shape[1:] == observation_space.shape:
        return True
    else:
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
            + f"environment, please use {observation_space.shape} or "
            + f"(n_env, {observation_space.n}) for the observation shape."
        )

def is_vectorized_dict_observation(observation: np.ndarray, observation_space: spaces.Dict) -> bool:
    """
    For dict observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    """
    # We first assume that all observations are not vectorized
    all_non_vectorized = True
    for key, subspace in observation_space.spaces.items():
        # This fails when the observation is not vectorized
        # or when it has the wrong shape
        if observation[key].shape != subspace.shape:
            all_non_vectorized = False
            break

    if all_non_vectorized:
        return False

    all_vectorized = True
    # Now we check that all observations are vectorized and have the correct shape
    for key, subspace in observation_space.spaces.items():
        if observation[key].shape[1:] != subspace.shape:
            all_vectorized = False
            break

    if all_vectorized:
        return True
    else:
        # Retrieve error message
        error_msg = ""
        try:
            is_vectorized_observation(observation[key], observation_space.spaces[key])
        except ValueError as e:
            error_msg = f"{e}"
        raise ValueError(
            f"There seems to be a mix of vectorized and non-vectorized observations. "
            f"Unexpected observation shape {observation[key].shape} for key {key} "
            f"of type {observation_space.spaces[key]}. {error_msg}"
        )

def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: spaces.Space) -> bool:
    """
    For every observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    """
    is_vec_obs_func_dict = {
        spaces.Box: is_vectorized_box_observation,
        spaces.Discrete: is_vectorized_discrete_observation,
        spaces.MultiDiscrete: is_vectorized_multidiscrete_observation,
        spaces.MultiBinary: is_vectorized_multibinary_observation,
        spaces.Dict: is_vectorized_dict_observation,
    }

    for space_type, is_vec_obs_func in is_vec_obs_func_dict.items():
        if isinstance(observation_space, space_type):
            return is_vec_obs_func(observation, observation_space)  # type: ignore[operator]
    else:
        # for-else happens if no break is called
        raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.")

def safe_mean(arr: Union[np.ndarray, list, deque]) -> float:
    """
    Compute the mean of an array if there is at least one element.
    For an empty array, return NaN. It is used for logging only.

    :param arr: Numpy array or list of values
    :return: Mean value, or NaN if the array is empty
    """
    return np.nan if len(arr) == 0 else float(np.mean(arr))  # type: ignore[arg-type]

def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
    """
    Extract parameters from the state dict of ``model``
    if the name contains one of the strings in ``included_names``.

    :param model: the model where the parameters come from.
    :param included_names: substrings of names to include.
    :return: List of parameter values (PyTorch tensors) that match the queried names.
    """
    return [param for name, param in model.state_dict().items() if any(key in name for key in included_names)]

def zip_strict(*iterables: Iterable) -> Iterable:
    r"""
    ``zip()`` function but enforces that iterables are of equal length.
    Raises ``ValueError`` if iterables are not of equal length.
    Code inspired by Stackoverflow answer for question #32954486.

    :param \*iterables: iterables to ``zip()``
    """
    # As in Stackoverflow #32954486, use a
    # new object for "empty" in case we have
    # Nones in the iterables.
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError("Iterables have different lengths")
        yield combo

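# Example (illustrative):
#
#     >>> list(zip_strict([1, 2], ["a", "b"]))
#     [(1, 'a'), (2, 'b')]
#     >>> list(zip_strict([1, 2], ["a"]))  # mismatched lengths raise an error
#     Traceback (most recent call last):
#     ...
#     ValueError: Iterables have different lengths
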
def polyak_update(
    params: Iterable[th.Tensor],
    target_params: Iterable[th.Tensor],
    tau: float,
) -> None:
    """
    Perform a Polyak average update on ``target_params`` using ``params``:
    target parameters are slowly updated towards the main parameters.
    ``tau``, the soft update coefficient, controls the interpolation:
    ``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``.
    The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors
    or a computation graph, reducing memory cost and improving performance. We scale the target params
    by ``1-tau`` (in place), add the new weights scaled by ``tau``, and store the result of the sum
    in the target params (in place).
    See https://github.com/DLR-RM/stable-baselines3/issues/93

    :param params: parameters to use to update the target params
    :param target_params: parameters to update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    """
    with th.no_grad():
        # zip does not raise an exception if lengths of parameters do not match.
        for param, target_param in zip_strict(params, target_params):
            target_param.data.mul_(1 - tau)
            th.add(target_param.data, param.data, alpha=tau, out=target_param.data)

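# Example (illustrative sketch): assuming ``policy_net`` and ``target_net`` are
# hypothetical networks with identical architectures, each call below nudges
# the target weights 0.5% of the way towards the online weights, i.e.
# target_param <- (1 - tau) * target_param + tau * param:
#
#     polyak_update(policy_net.parameters(), target_net.parameters(), tau=0.005)
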
def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
    """
    Moves the observation to the given device.

    :param obs: the observation, as a numpy array or a dict of numpy arrays
    :param device: PyTorch device
    :return: PyTorch tensor of the observation on the desired device.
    """
    if isinstance(obs, np.ndarray):
        return th.as_tensor(obs, device=device)
    elif isinstance(obs, dict):
        return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
    else:
        raise Exception(f"Unrecognized type of observation {type(obs)}")

def should_collect_more_steps(
    train_freq: TrainFreq,
    num_collected_steps: int,
    num_collected_episodes: int,
) -> bool:
    """
    Helper used in ``collect_rollouts()`` of off-policy algorithms
    to determine the termination condition.

    :param train_freq: How much experience should be collected before updating the policy.
    :param num_collected_steps: The number of already collected steps.
    :param num_collected_episodes: The number of already collected episodes.
    :return: Whether to continue or not collecting experience
        by doing rollouts of the current policy.
    """
    if train_freq.unit == TrainFrequencyUnit.STEP:
        return num_collected_steps < train_freq.frequency
    elif train_freq.unit == TrainFrequencyUnit.EPISODE:
        return num_collected_episodes < train_freq.frequency
    else:
        raise ValueError(
            "The unit of the `train_freq` must be either TrainFrequencyUnit.STEP "
            f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!"
        )

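# Example (illustrative): with ``TrainFreq(4, TrainFrequencyUnit.STEP)``,
# rollout collection continues until 4 steps have been gathered.
#
#     >>> freq = TrainFreq(4, TrainFrequencyUnit.STEP)
#     >>> should_collect_more_steps(freq, num_collected_steps=3, num_collected_episodes=0)
#     True
#     >>> should_collect_more_steps(freq, num_collected_steps=4, num_collected_episodes=0)
#     False
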
def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
    """
    Retrieve system and python env info for the current system.

    :param print_info: Whether to print those infos or not
    :return: Dictionary summing up the version for each relevant package
        and a formatted string.
    """
    env_info = {
        # In OS, a regex is used to add a space between a "#" and a number to avoid
        # wrongly linking to another issue on GitHub. Example: turn "#42" to "# 42".
        "OS": re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}"),
        "Python": platform.python_version(),
        "Stable-Baselines3": sb3.__version__,
        "PyTorch": th.__version__,
        "GPU Enabled": str(th.cuda.is_available()),
        "Numpy": np.__version__,
        "Cloudpickle": cloudpickle.__version__,
        "Gymnasium": gym.__version__,
    }
    try:
        import gym as openai_gym

        env_info.update({"OpenAI Gym": openai_gym.__version__})
    except ImportError:
        pass

    env_info_str = ""
    for key, value in env_info.items():
        env_info_str += f"- {key}: {value}\n"
    if print_info:
        print(env_info_str)
    return env_info, env_info_str