Utils

stable_baselines3.common.utils.check_for_correct_spaces(env, observation_space, action_space)

Checks that the environment has the same spaces as the provided ones. Used by BaseAlgorithm to check whether the spaces match after loading the model with a given env. Checked parameters:
  • observation_space

  • action_space

Parameters:
  • env (Union[Env, VecEnv]) – Environment to check for valid spaces

  • observation_space (Space) – Observation space to check against

  • action_space (Space) – Action space to check against

Return type:

None

stable_baselines3.common.utils.check_shape_equal(space1, space2)

If the spaces are Box, check that they have the same shape.

If the spaces are Dict, it recursively checks the subspaces.

Parameters:
  • space1 (Space) – Space

  • space2 (Space) – Other space

Return type:

None

stable_baselines3.common.utils.configure_logger(verbose=0, tensorboard_log=None, tb_log_name='', reset_num_timesteps=True)

Configure the logger’s outputs.

Parameters:
  • verbose (int) – Verbosity level: 0 for no output, 1 for the standard output to be part of the logger outputs

  • tensorboard_log (Optional[str]) – the log location for tensorboard (if None, no logging)

  • tb_log_name (str) – name of the TensorBoard run; each run is logged to a subfolder of tensorboard_log named tb_log_name_<run number>

  • reset_num_timesteps (bool) – Whether the num_timesteps attribute is reset or not. It allows to continue a previous learning curve (reset_num_timesteps=False) or start from t=0 (reset_num_timesteps=True, the default).

Return type:

Logger

Returns:

The logger object

stable_baselines3.common.utils.constant_fn(val)

Create a function that returns a constant. It is useful for learning rate schedules (to avoid code duplication).

Parameters:

val (float) – constant value

Return type:

Callable[[float], float]

Returns:

Constant schedule function.
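A constant schedule is just a closure that ignores its progress_remaining argument. A minimal sketch of the idea (the name constant_schedule is hypothetical; this is not the library's function):

```python
from typing import Callable


def constant_schedule(val: float) -> Callable[[float], float]:
    """Hypothetical stand-in for constant_fn: return a schedule that ignores progress."""

    def func(progress_remaining: float) -> float:
        # The progress argument is accepted only to match the schedule signature.
        return val

    return func


lr_schedule = constant_schedule(3e-4)
print(lr_schedule(1.0), lr_schedule(0.25))
```

Because it has the same Callable[[float], float] shape as a real schedule, a constant value can be passed wherever a schedule is expected.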

stable_baselines3.common.utils.explained_variance(y_pred, y_true)

Computes the fraction of the variance of y_true that y_pred explains. Returns 1 - Var[y_true - y_pred] / Var[y_true].

Interpretation:
  • ev = 0: might as well have predicted zero

  • ev = 1: perfect prediction

  • ev < 0: worse than just predicting zero

Parameters:
  • y_pred (ndarray) – the prediction

  • y_true (ndarray) – the expected value

Return type:

ndarray

Returns:

explained variance of y_pred with respect to y_true
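The formula and its interpretation can be checked numerically with a NumPy sketch. The helper name explained_variance_sketch is hypothetical, and returning NaN when Var[y_true] is zero is an assumption of this sketch, made to avoid a division by zero.

```python
import numpy as np


def explained_variance_sketch(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """Illustrative reimplementation of 1 - Var[y_true - y_pred] / Var[y_true]."""
    var_y = np.var(y_true)
    # Assumption: return NaN when y_true is constant, to avoid dividing by zero.
    if var_y == 0:
        return float("nan")
    return float(1.0 - np.var(y_true - y_pred) / var_y)


y_true = np.array([1.0, 2.0, 3.0, 4.0])
print(explained_variance_sketch(y_true.copy(), y_true))  # 1.0: perfect prediction
print(explained_variance_sketch(np.zeros(4), y_true))    # 0.0: same as predicting zero
print(explained_variance_sketch(-y_true, y_true))        # -3.0: worse than predicting zero
```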

stable_baselines3.common.utils.get_device(device='auto')

Retrieve the PyTorch device. It first checks that the requested device is available. For now, only cpu and cuda are supported. By default ('auto'), it tries to use the GPU; if CUDA is requested but not available, it falls back to the CPU.

Parameters:

device (Union[device, str]) – One of ‘auto’, ‘cuda’, ‘cpu’

Return type:

device

Returns:

Supported PyTorch device

stable_baselines3.common.utils.get_latest_run_id(log_path='', log_name='')

Returns the latest run number for the given log name and log path, i.e. the greatest number among the run folders (0 if there is none).

Parameters:
  • log_path (str) – Path to the log folder containing several runs.

  • log_name (str) – Name of the experiment. Each run is stored in a folder named log_name_1, log_name_2, …

Return type:

int

Returns:

latest run number
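A minimal standard-library sketch of that search, assuming run folders named log_name_1, log_name_2, … as described above. The helper name latest_run_id_sketch is hypothetical. Note that the comparison is numeric, so PPO_10 wins over PPO_2, and folders whose suffix is not a plain number are ignored.

```python
import glob
import os
import tempfile


def latest_run_id_sketch(log_path: str = "", log_name: str = "") -> int:
    """Hypothetical reimplementation: greatest N among folders named f"{log_name}_{N}"."""
    max_run_id = 0  # 0 means "no previous run found"
    pattern = os.path.join(glob.escape(log_path), f"{glob.escape(log_name)}_[0-9]*")
    for path in glob.glob(pattern):
        prefix, _, suffix = os.path.basename(path).rpartition("_")
        # Only count exact matches such as "PPO_3" (not "PPO_3a").
        if prefix == log_name and suffix.isdigit():
            max_run_id = max(max_run_id, int(suffix))
    return max_run_id


with tempfile.TemporaryDirectory() as tmp:
    print(latest_run_id_sketch(tmp, "PPO"))  # 0: nothing logged yet
    for name in ("PPO_1", "PPO_2", "PPO_10", "DQN_7"):
        os.makedirs(os.path.join(tmp, name))
    print(latest_run_id_sketch(tmp, "PPO"))  # 10
```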

stable_baselines3.common.utils.get_linear_fn(start, end, end_fraction)

Create a function that interpolates linearly between start and end as progress_remaining decreases from 1 to 1 - end_fraction, and returns end from then on. This is used in DQN for linearly annealing the exploration fraction (epsilon for the epsilon-greedy strategy).

Parameters:
  • start (float) – value to start with when progress_remaining = 1

  • end (float) – value to end with; it is reached when progress_remaining = 1 - end_fraction and kept until progress_remaining = 0

  • end_fraction (float) – fraction of the complete training process after which end is reached, e.g. with 0.1, end is reached after 10% of training.

Return type:

Callable[[float], float]

Returns:

Linear schedule function.
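The schedule can be sketched in a few lines of plain Python. This is an illustrative reimplementation under the parameter semantics above (the name linear_schedule_sketch is hypothetical), here annealing epsilon from 1.0 to 0.05 over the first 10% of training.

```python
from typing import Callable


def linear_schedule_sketch(start: float, end: float, end_fraction: float) -> Callable[[float], float]:
    """Illustrative linear schedule driven by progress_remaining (1 -> 0)."""

    def func(progress_remaining: float) -> float:
        elapsed = 1.0 - progress_remaining  # fraction of training already done
        if elapsed > end_fraction:
            return end  # annealing finished: hold the final value
        return start + elapsed * (end - start) / end_fraction

    return func


epsilon = linear_schedule_sketch(1.0, 0.05, 0.1)
for progress_remaining in (1.0, 0.95, 0.9, 0.5, 0.0):
    print(progress_remaining, round(epsilon(progress_remaining), 3))
```

Halfway through the annealing window (progress_remaining = 0.95) the value is halfway between start and end, and it stays at end for the remaining 90% of training.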

stable_baselines3.common.utils.get_parameters_by_name(model, included_names)

Extract parameters from the state dict of model if the name contains one of the strings in included_names.

Parameters:
  • model (Module) – the model where the parameters come from.

  • included_names (Iterable[str]) – substrings of names to include.

Return type:

List[Tensor]

Returns:

List of parameter values (PyTorch tensors) whose names match the queried names.

stable_baselines3.common.utils.get_schedule_fn(value_schedule)

Transform (if needed) the learning rate and the clip range (for PPO) into a callable.

Parameters:

value_schedule (Union[Callable[[float], float], float]) – Constant value or schedule function

Return type:

Callable[[float], float]

Returns:

Schedule function (can return constant value)
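The conversion can be sketched as follows: numbers are wrapped in a constant schedule and callables are returned unchanged. The name to_schedule and the TypeError raised for other inputs are assumptions of this sketch; the library's own error handling may differ.

```python
from typing import Callable, Union

Schedule = Callable[[float], float]


def to_schedule(value_schedule: Union[Schedule, float]) -> Schedule:
    """Hypothetical stand-in for get_schedule_fn: wrap numbers, pass callables through."""
    if isinstance(value_schedule, (float, int)):
        constant = float(value_schedule)  # cast ints such as 1 to 1.0
        return lambda _progress_remaining: constant
    if not callable(value_schedule):
        raise TypeError(f"Expected a float or a callable, got {type(value_schedule)!r}")
    return value_schedule


clip_range = to_schedule(0.2)                    # constant clip range
learning_rate = to_schedule(lambda p: 3e-4 * p)  # already a schedule: returned as-is
print(clip_range(0.5), learning_rate(0.5))
```

Normalizing both forms up front lets the training loop always call the hyperparameter with progress_remaining, without checking its type each time.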

stable_baselines3.common.utils.get_system_info(print_info=True)

Retrieve system and Python environment information for the current system.

Parameters:

print_info (bool) – Whether to print this information

Return type:

Tuple[Dict[str, str], str]

Returns:

Dictionary summing up the version for each relevant package and a formatted string.

stable_baselines3.common.utils.is_vectorized_box_observation(observation, observation_space)

For box observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters:
  • observation (ndarray) – the input observation to validate

  • observation_space (Box) – the observation space

Return type:

bool

Returns:

whether the given observation is vectorized or not
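The shape test can be sketched with NumPy as below. The sketch takes the space's shape tuple instead of a Box (a simplification), and its rule (a shape of (n_envs, *shape) is a batch, any other mismatch raises ValueError) is our reading of "detects and validates the shape", not a copy of the library's code. The name is_vectorized_box_sketch is hypothetical.

```python
from typing import Tuple

import numpy as np


def is_vectorized_box_sketch(observation: np.ndarray, space_shape: Tuple[int, ...]) -> bool:
    """Hypothetical shape check; space_shape plays the role of observation_space.shape."""
    if observation.shape == space_shape:
        return False  # a single observation
    if observation.shape[1:] == space_shape:
        return True  # a batch: (n_envs, *space_shape)
    raise ValueError(
        f"Unexpected observation shape {observation.shape}, "
        f"expected {space_shape} or (n_envs, {', '.join(map(str, space_shape))})"
    )


space_shape = (4,)  # e.g. a Box observation space with shape (4,)
print(is_vectorized_box_sketch(np.zeros(4), space_shape))       # False
print(is_vectorized_box_sketch(np.zeros((8, 4)), space_shape))  # True
```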

stable_baselines3.common.utils.is_vectorized_dict_observation(observation, observation_space)

For dict observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters:
  • observation (ndarray) – the input observation to validate

  • observation_space (Dict) – the observation space

Return type:

bool

Returns:

whether the given observation is vectorized or not

stable_baselines3.common.utils.is_vectorized_discrete_observation(observation, observation_space)

For discrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters:
  • observation (Union[int, ndarray]) – the input observation to validate

  • observation_space (Discrete) – the observation space

Return type:

bool

Returns:

whether the given observation is vectorized or not
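For a Discrete space, a single observation is a scalar index, so the check reduces to the number of dimensions. A minimal sketch under that assumption (the name is_vectorized_discrete_sketch is hypothetical, and it omits the observation_space argument because the shape alone decides):

```python
from typing import Union

import numpy as np


def is_vectorized_discrete_sketch(observation: Union[int, np.ndarray]) -> bool:
    """Hypothetical check: a scalar is one observation, a 1-D array is a batch."""
    if isinstance(observation, int) or np.shape(observation) == ():
        return False  # plain int or 0-d value such as np.array(3)
    if np.ndim(observation) == 1:
        return True  # one discrete index per environment: shape (n_envs,)
    raise ValueError(f"Unexpected observation shape {np.shape(observation)} for a Discrete space")


print(is_vectorized_discrete_sketch(3))                    # False
print(is_vectorized_discrete_sketch(np.array(3)))          # False
print(is_vectorized_discrete_sketch(np.array([0, 2, 1])))  # True
```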

stable_baselines3.common.utils.is_vectorized_multibinary_observation(observation, observation_space)

For multibinary observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters:
  • observation (ndarray) – the input observation to validate

  • observation_space (MultiBinary) – the observation space

Return type:

bool

Returns:

whether the given observation is vectorized or not

stable_baselines3.common.utils.is_vectorized_multidiscrete_observation(observation, observation_space)

For multidiscrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters:
  • observation (ndarray) – the input observation to validate

  • observation_space (MultiDiscrete) – the observation space

Return type:

bool

Returns:

whether the given observation is vectorized or not

stable_baselines3.common.utils.is_vectorized_observation(observation, observation_space)

For every observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters:
  • observation (Union[int, ndarray]) – the input observation to validate

  • observation_space (Space) – the observation space

Return type:

bool

Returns:

whether the given observation is vectorized or not

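stable_baselines3.common.utils.obs_as_tensor(obs, device)

Converts the observation to a PyTorch tensor (or a dict of tensors, for dict observations) on the given device.

Parameters:
  • obs (Union[ndarray, Dict[Union[str, int], ndarray]]) – the observation: a NumPy array, or a dict of NumPy arrays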

  • device (device) – PyTorch device

Return type:

Union[Tensor, Dict[Union[str, int], Tensor]]

Returns:

PyTorch tensor of the observation on the desired device.

stable_baselines3.common.utils.polyak_update(params, target_params, tau)

Perform a Polyak average update on target_params using params: the target parameters are slowly updated towards the main parameters, i.e. target_param = (1 - tau) * target_param + tau * param. tau, the soft update coefficient, controls the interpolation: tau=1 copies the parameters into the target ones, whereas tau=0 leaves the target unchanged.

The update is done in place, with no_grad, and therefore creates neither intermediate tensors nor a computation graph, which reduces memory cost and improves performance: the target params are scaled by 1-tau in place, then the new weights, scaled by tau, are added and the sum is stored back in the target params (in place). See https://github.com/DLR-RM/stable-baselines3/issues/93

Parameters:
  • params (Iterable[Tensor]) – parameters to use to update the target params

  • target_params (Iterable[Tensor]) – parameters to update

  • tau (float) – the soft update coefficient (“Polyak update”, between 0 and 1)

Return type:

None
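To make the update rule concrete, here is a NumPy sketch in which arrays stand in for PyTorch tensors (an assumption for illustration; this is not the library's implementation, and polyak_update_sketch is a hypothetical name). Like the description above, it modifies the targets in place; unlike the PyTorch version, the expression tau * param allocates a temporary array.

```python
from typing import Iterable

import numpy as np


def polyak_update_sketch(params: Iterable[np.ndarray], target_params: Iterable[np.ndarray], tau: float) -> None:
    """NumPy sketch of target <- (1 - tau) * target + tau * param, applied in place."""
    param_list, target_list = list(params), list(target_params)
    if len(param_list) != len(target_list):
        raise ValueError("params and target_params must have the same length")
    for param, target in zip(param_list, target_list):
        target *= 1.0 - tau    # scale the target array in place
        target += tau * param  # add the scaled online weights in place


online = [np.array([1.0, 2.0]), np.array([[4.0]])]
target = [np.zeros(2), np.zeros((1, 1))]
polyak_update_sketch(online, target, tau=0.5)
print(target[0], target[1])  # each target moved halfway towards the online weights
```

With the small values of tau typically used for target networks, the target trails the online network smoothly instead of jumping to it at every update.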

stable_baselines3.common.utils.safe_mean(arr)

Compute the mean of an array if there is at least one element. For empty array, return NaN. It is used for logging only.

Parameters:

arr (Union[ndarray, list, deque]) – Numpy array or list of values

Return type:

ndarray

Returns:

The mean of arr, or NaN if arr is empty.
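A NumPy sketch of this behavior follows (safe_mean_sketch is a hypothetical name). Presumably the empty-input special case exists because np.mean of an empty sequence emits a RuntimeWarning before returning NaN, which would be noise in logging code, for example before the first episode has finished.

```python
import math
from collections import deque
from typing import Union

import numpy as np


def safe_mean_sketch(arr: Union[np.ndarray, list, deque]) -> float:
    """Mean of arr, or NaN when arr is empty (skipping NumPy's empty-slice warning)."""
    return float("nan") if len(arr) == 0 else float(np.mean(arr))


episode_rewards = deque([10.0, 12.0, 14.0], maxlen=100)
print(safe_mean_sketch(episode_rewards))  # 12.0
print(math.isnan(safe_mean_sketch([])))   # True
```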

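stable_baselines3.common.utils.set_random_seed(seed, using_cuda=False)

Seed the different random generators (Python's random module, NumPy and PyTorch).

Parameters:
  • seed (int) – the seed value

  • using_cuda (bool) – whether to also make cuDNN operations deterministic (this may reduce performance)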

Return type:

None
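Seeding makes runs reproducible. The sketch below, with the hypothetical helper seed_python_and_numpy, covers only the Python and NumPy generators (PyTorch is left out to keep it dependency-light) and checks that reseeding reproduces the same draws.

```python
import random

import numpy as np


def seed_python_and_numpy(seed: int) -> None:
    """Hypothetical helper: seeds only Python's and NumPy's global generators."""
    random.seed(seed)
    np.random.seed(seed)


seed_python_and_numpy(42)
first = (random.random(), np.random.rand())
seed_python_and_numpy(42)
second = (random.random(), np.random.rand())
print(first == second)  # True: same seed, same draws
```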

stable_baselines3.common.utils.should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes)

Helper used in collect_rollouts() of off-policy algorithms to determine the termination condition.

Parameters:
  • train_freq (TrainFreq) – How much experience should be collected before updating the policy.

  • num_collected_steps (int) – The number of already collected steps.

  • num_collected_episodes (int) – The number of already collected episodes.

Return type:

bool

Returns:

Whether to continue collecting experience by doing rollouts of the current policy.
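The termination condition can be sketched as follows, assuming a train frequency made of a number and a unit that counts either environment steps or episodes. The types Freq and Unit below are hypothetical stand-ins for TrainFreq, and should_collect_more_sketch is not the library's function.

```python
from enum import Enum
from typing import NamedTuple


class Unit(Enum):
    # Hypothetical unit of a train frequency.
    STEP = "step"
    EPISODE = "episode"


class Freq(NamedTuple):
    # Hypothetical stand-in for TrainFreq: collect `frequency` units between updates.
    frequency: int
    unit: Unit


def should_collect_more_sketch(train_freq: Freq, num_collected_steps: int, num_collected_episodes: int) -> bool:
    if train_freq.unit == Unit.STEP:
        return num_collected_steps < train_freq.frequency
    if train_freq.unit == Unit.EPISODE:
        return num_collected_episodes < train_freq.frequency
    raise ValueError(f"Unknown train frequency unit: {train_freq.unit}")


every_4_steps = Freq(4, Unit.STEP)
print(should_collect_more_sketch(every_4_steps, 3, 0))  # True: keep collecting
print(should_collect_more_sketch(every_4_steps, 4, 0))  # False: time to train
```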

stable_baselines3.common.utils.update_learning_rate(optimizer, learning_rate)

Update the learning rate of every parameter group of the given optimizer. Useful when following a learning rate schedule, such as a linear one.

Parameters:
  • optimizer (Optimizer) – PyTorch optimizer

  • learning_rate (float) – New learning rate value

Return type:

None

stable_baselines3.common.utils.zip_strict(*iterables)

Like zip(), but enforces that the iterables are of equal length: raises ValueError if they are not. Code inspired by the Stack Overflow answer to question #32954486.

Parameters:

*iterables (Iterable) – iterables to zip()

Return type:

Iterable
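A sketch of the technique: pad the shorter iterables with a unique sentinel via itertools.zip_longest and raise as soon as the sentinel shows up. The name zip_strict_sketch is hypothetical. Comparing with the is operator rather than == keeps the check safe for items such as None or NumPy arrays.

```python
from itertools import zip_longest
from typing import Iterable, Iterator, Tuple


def zip_strict_sketch(*iterables: Iterable) -> Iterator[Tuple]:
    """Like zip(), but raise ValueError when the iterables have different lengths."""
    # A fresh object cannot collide with real items (even None), so it marks exhaustion.
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if any(item is sentinel for item in combo):
            raise ValueError("Iterables have different lengths")
        yield combo


print(list(zip_strict_sketch([1, 2], "ab")))
try:
    list(zip_strict_sketch([1, 2, 3], "ab"))
except ValueError as err:
    print(err)  # Iterables have different lengths
```

On Python 3.10 and later, the built-in zip(..., strict=True) provides the same guarantee.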