Utils

stable_baselines3.common.utils.check_for_correct_spaces(env, observation_space, action_space)

Checks that the environment has the same spaces as the provided ones. Used by BaseAlgorithm to check if the spaces match after loading the model with a given env. Checked parameters:

  • observation_space

  • action_space

Parameters
  • env (Union[Env, VecEnv]) – Environment to check for valid spaces

  • observation_space (Space) – Observation space to check against

  • action_space (Space) – Action space to check against

Return type

None

stable_baselines3.common.utils.configure_logger(verbose=0, tensorboard_log=None, tb_log_name='', reset_num_timesteps=True)

Configure the logger’s outputs.

Parameters
  • verbose (int) – the verbosity level: 0 no output, 1 info, 2 debug

  • tensorboard_log (Optional[str]) – the log location for tensorboard (if None, no logging)

  • tb_log_name (str) – the name of the run for the tensorboard log

  • reset_num_timesteps (bool) – Whether the num_timesteps attribute is reset or not. This allows continuing a previous learning curve (reset_num_timesteps=False) or starting from t=0 (reset_num_timesteps=True, the default).

Return type

Logger

Returns

The logger object

stable_baselines3.common.utils.constant_fn(val)

Create a function that returns a constant. It is useful for learning rate schedules (to avoid code duplication).

Parameters

val (float) – constant value

Return type

Callable[[float], float]

Returns

Constant schedule function

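Schedules in this module are callables that take progress_remaining (1 at the start of training, 0 at the end) and return a value. The following is a minimal pure-Python sketch of a constant schedule under that convention; make_constant_schedule is a hypothetical name, not part of the library:

```python
from typing import Callable


# Hypothetical helper illustrating constant_fn's documented behavior (not the library source).
def make_constant_schedule(val: float) -> Callable[[float], float]:
    """Return a schedule that ignores progress and always yields `val`."""

    def schedule(_progress_remaining: float) -> float:
        # The argument is accepted only so the signature matches other schedules.
        return val

    return schedule


lr_schedule = make_constant_schedule(3e-4)
print(lr_schedule(1.0), lr_schedule(0.0))  # 0.0003 0.0003
```
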
stable_baselines3.common.utils.explained_variance(y_pred, y_true)

Computes the fraction of variance that y_pred explains about y_true. Returns 1 - Var[y_true - y_pred] / Var[y_true].

Interpretation:

  • ev=0 => might as well have predicted zero

  • ev=1 => perfect prediction

  • ev<0 => worse than just predicting zero

Parameters
  • y_pred (ndarray) – the prediction

  • y_true (ndarray) – the expected value

Return type

ndarray

Returns

explained variance of ypred and y
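
A dependency-free sketch of the formula above, using population variances from Python's statistics module in place of NumPy arrays. explained_variance_sketch is a hypothetical name, and returning NaN when Var[y_true] is 0 is an assumption of this sketch:

```python
from statistics import pvariance
from typing import Sequence


# Hypothetical sketch of 1 - Var[y_true - y_pred] / Var[y_true] (not the library source).
def explained_variance_sketch(y_pred: Sequence[float], y_true: Sequence[float]) -> float:
    var_y = pvariance(y_true)
    if var_y == 0:
        # Assumption: the ratio is undefined for a constant target, so return NaN.
        return float("nan")
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    return 1.0 - pvariance(residuals) / var_y


y = [1.0, 2.0, 3.0, 4.0]
print(explained_variance_sketch(y, y))                     # 1.0 (perfect prediction)
print(explained_variance_sketch([0.0] * 4, y))             # 0.0 (predicting zero)
print(explained_variance_sketch([4.0, 3.0, 2.0, 1.0], y))  # -3.0 (worse than zero)
```
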

stable_baselines3.common.utils.get_device(device='auto')

Retrieve PyTorch device. It checks that the requested device is available first. For now, it supports only CPU and CUDA. By default, it tries to use the GPU.

Parameters

device (Union[device, str]) – One of ‘auto’, ‘cuda’, ‘cpu’

Return type

device

Returns

Supported PyTorch device

stable_baselines3.common.utils.get_latest_run_id(log_path=None, log_name='')

Returns the latest run number for the given log name and log path, by finding the greatest number among the directories named log_name_<number> in log_path.

Return type

int

Returns

latest run number
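
A minimal sketch of the lookup, assuming each run lives in a directory named <log_name>_<N> directly inside log_path (for example PPO_1, PPO_2). latest_run_id_sketch is a hypothetical helper, not the library's implementation:

```python
import os
import re
import tempfile


# Hypothetical helper; assumes runs are stored as <log_name>_<N> entries in log_path.
def latest_run_id_sketch(log_path: str, log_name: str) -> int:
    """Return the greatest N among entries named <log_name>_<N> in log_path, or 0."""
    pattern = re.compile(re.escape(log_name) + r"_(\d+)")
    max_run_id = 0
    if not os.path.isdir(log_path):
        return max_run_id  # no log folder yet, so no previous run
    for entry in os.listdir(log_path):
        match = pattern.fullmatch(entry)
        if match is not None:
            max_run_id = max(max_run_id, int(match.group(1)))
    return max_run_id


with tempfile.TemporaryDirectory() as tmp:
    for name in ("PPO_1", "PPO_3", "DQN_7"):
        os.mkdir(os.path.join(tmp, name))
    print(latest_run_id_sketch(tmp, "PPO"))  # 3
```
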

stable_baselines3.common.utils.get_linear_fn(start, end, end_fraction)

Create a function that interpolates linearly between start and end between progress_remaining = 1 and progress_remaining = 1 - end_fraction. This is used in DQN for linearly annealing the exploration rate (epsilon for the epsilon-greedy strategy).

Parameters
  • start (float) – value to start with if progress_remaining = 1

  • end (float) – value to end with, reached when progress_remaining = 1 - end_fraction and kept until progress_remaining = 0

  • end_fraction (float) – fraction of the complete training process after which end is reached, e.g. with 0.1, end is reached after 10% of the complete training process.

Return type

Callable[[float], float]

Returns

Linear schedule function

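A pure-Python sketch of the interpolation described above, written against the documented behavior rather than copied from the library; make_linear_schedule is a hypothetical name. With end_fraction = 0.1, the value moves from start to end while progress_remaining goes from 1 down to 0.9, then stays at end:

```python
from typing import Callable


# Hypothetical helper illustrating get_linear_fn's documented behavior (not the library source).
def make_linear_schedule(start: float, end: float, end_fraction: float) -> Callable[[float], float]:
    """Linearly interpolate from `start` to `end`, then hold `end`."""

    def schedule(progress_remaining: float) -> float:
        elapsed = 1.0 - progress_remaining  # fraction of training already done
        if elapsed > end_fraction:
            return end
        return start + elapsed * (end - start) / end_fraction

    return schedule


epsilon = make_linear_schedule(start=1.0, end=0.05, end_fraction=0.1)
print(epsilon(1.0))   # 1.0 (start of training)
print(epsilon(0.95))  # ~0.525 (halfway through the annealing phase)
print(epsilon(0.5))   # 0.05 (annealing finished)
```
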
stable_baselines3.common.utils.get_schedule_fn(value_schedule)

Transform (if needed) the learning rate and clip range (for PPO) into a callable.

Parameters

value_schedule (Union[Callable[[float], float], float, int]) – constant value or schedule function

Return type

Callable[[float], float]

Returns

Schedule function (a constant value is wrapped into a function that always returns it)

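A sketch of the normalization step, assuming the convention that plain numbers become constant schedules and callables are returned unchanged; to_schedule and the Schedule alias are hypothetical names, not the library's API:

```python
from typing import Callable, Union

# Hypothetical alias: a schedule maps progress_remaining to a value.
Schedule = Callable[[float], float]


# Hypothetical helper illustrating get_schedule_fn's documented behavior (not the library source).
def to_schedule(value_schedule: Union[Schedule, float, int]) -> Schedule:
    """Wrap a constant into a schedule; pass callables through unchanged."""
    if isinstance(value_schedule, (float, int)):
        constant = float(value_schedule)
        return lambda _progress_remaining: constant
    if not callable(value_schedule):
        raise TypeError(f"expected a number or a callable, got {type(value_schedule)!r}")
    return value_schedule


clip_range = to_schedule(0.2)
learning_rate = to_schedule(lambda progress_remaining: 3e-4 * progress_remaining)
print(clip_range(0.5), learning_rate(0.5))  # 0.2 0.00015
```
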
stable_baselines3.common.utils.is_vectorized_box_observation(observation, observation_space)

For box observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters
  • observation (ndarray) – the input observation to validate

  • observation_space (Box) – the observation space

Return type

bool

Returns

whether the given observation is vectorized or not
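
The shape check can be sketched with plain tuples, assuming the convention that a single observation has exactly the space's shape while a batch from a vectorized environment adds one leading dimension (the number of environments). box_is_vectorized is a hypothetical helper that works on shapes only:

```python
from typing import Tuple


# Hypothetical helper; shapes stand in for arrays and spaces (not the library source).
def box_is_vectorized(obs_shape: Tuple[int, ...], space_shape: Tuple[int, ...]) -> bool:
    """Decide from shapes alone whether a Box observation is batched."""
    if obs_shape == space_shape:
        return False  # a single observation
    if obs_shape[1:] == space_shape:
        return True  # leading dimension = number of environments
    raise ValueError(
        f"Unexpected observation shape {obs_shape} for Box space of shape {space_shape}: "
        f"expected {space_shape} or (n_env, {', '.join(map(str, space_shape))})"
    )


print(box_is_vectorized((84, 84, 3), (84, 84, 3)))     # False
print(box_is_vectorized((8, 84, 84, 3), (84, 84, 3)))  # True
```
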

stable_baselines3.common.utils.is_vectorized_dict_observation(observation, observation_space)

For dict observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters
  • observation (ndarray) – the input observation to validate

  • observation_space (Dict) – the observation space

Return type

bool

Returns

whether the given observation is vectorized or not
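
For a Dict space the same decision is made per key. The sketch below assumes the observation is unbatched only if every entry matches its sub-space shape exactly, batched only if every entry matches after dropping the leading dimension, and invalid otherwise; dict_is_vectorized is a hypothetical helper operating on shapes:

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


# Hypothetical helper; shapes stand in for arrays and spaces (not the library source).
def dict_is_vectorized(obs_shapes: Dict[str, Shape], space_shapes: Dict[str, Shape]) -> bool:
    """Return False if every entry is unbatched, True if every entry is batched."""
    if all(obs_shapes[key] == shape for key, shape in space_shapes.items()):
        return False
    if all(obs_shapes[key][1:] == shape for key, shape in space_shapes.items()):
        return True
    raise ValueError(f"Inconsistent or invalid shapes {obs_shapes} for Dict space {space_shapes}")


space = {"image": (64, 64, 3), "vector": (5,)}
print(dict_is_vectorized({"image": (64, 64, 3), "vector": (5,)}, space))       # False
print(dict_is_vectorized({"image": (4, 64, 64, 3), "vector": (4, 5)}, space))  # True
```
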

stable_baselines3.common.utils.is_vectorized_discrete_observation(observation, observation_space)

For discrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters
  • observation (ndarray) – the input observation to validate

  • observation_space (Discrete) – the observation space

Return type

bool

Returns

whether the given observation is vectorized or not
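
For a Discrete space, a single observation is a scalar integer and a batch is a 1-D array with one integer per environment. A shape-only sketch under that assumption (discrete_is_vectorized is a hypothetical name):

```python
from typing import Tuple


# Hypothetical helper; the shape stands in for the array (not the library source).
def discrete_is_vectorized(obs_shape: Tuple[int, ...]) -> bool:
    """A Discrete observation is a scalar; a batch of them is a 1-D array."""
    if obs_shape == ():
        return False  # a single integer observation
    if len(obs_shape) == 1:
        return True  # one integer per environment
    raise ValueError(f"Unexpected observation shape {obs_shape} for Discrete space: expected () or (n_env,)")


print(discrete_is_vectorized(()))    # False
print(discrete_is_vectorized((8,)))  # True
```
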

stable_baselines3.common.utils.is_vectorized_multibinary_observation(observation, observation_space)

For multibinary observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters
  • observation (ndarray) – the input observation to validate

  • observation_space (MultiBinary) – the observation space

Return type

bool

Returns

whether the given observation is vectorized or not

stable_baselines3.common.utils.is_vectorized_multidiscrete_observation(observation, observation_space)

For multidiscrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters
  • observation (ndarray) – the input observation to validate

  • observation_space (MultiDiscrete) – the observation space

Return type

bool

Returns

whether the given observation is vectorized or not

stable_baselines3.common.utils.is_vectorized_observation(observation, observation_space)

For every observation type, detects and validates the shape, then returns whether or not the observation is vectorized.

Parameters
  • observation (ndarray) – the input observation to validate

  • observation_space (Space) – the observation space

Return type

bool

Returns

whether the given observation is vectorized or not

stable_baselines3.common.utils.obs_as_tensor(obs, device)

Moves the observation to the given device.

Parameters
  • obs (Union[ndarray, Dict[Union[str, int], ndarray]]) – the observation (a NumPy array or a dict of NumPy arrays)

  • device (device) – PyTorch device

Return type

Union[Tensor, Dict[Union[str, int], Tensor]]

Returns

PyTorch tensor of the observation on the desired device.

stable_baselines3.common.utils.polyak_update(params, target_params, tau)

Perform a Polyak average update on target_params using params: the target parameters are slowly updated towards the main parameters. tau, the soft update coefficient, controls the interpolation: tau=1 corresponds to copying the parameters to the target ones, whereas nothing happens when tau=0.

The Polyak update is done in place, with no_grad, and therefore creates neither intermediate tensors nor a computation graph, which reduces memory cost and improves performance. We scale the target params by 1-tau (in place), add the new weights scaled by tau, and store the result of the sum in the target params (in place), i.e. target_params = (1 - tau) * target_params + tau * params. See https://github.com/DLR-RM/stable-baselines3/issues/93

Parameters
  • params (Iterable[Parameter]) – parameters to use to update the target params

  • target_params (Iterable[Parameter]) – parameters to update

  • tau (float) – the soft update coefficient (“Polyak update”, between 0 and 1)

Return type

None
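
The update rule can be illustrated on plain Python lists, which stand in for the parameter tensors (the library itself operates on PyTorch parameters in place under no_grad). polyak_update_sketch is a hypothetical name; each element is updated as target = (1 - tau) * target + tau * param:

```python
from typing import List, Sequence


# Hypothetical sketch on lists of floats instead of PyTorch parameters (not the library source).
def polyak_update_sketch(params: Sequence[float], target_params: List[float], tau: float) -> None:
    """In-place soft update: target <- (1 - tau) * target + tau * param."""
    if len(params) != len(target_params):
        raise ValueError("params and target_params must have the same length")
    for i, param in enumerate(params):
        target_params[i] = (1.0 - tau) * target_params[i] + tau * param


online = [1.0, 2.0]
target = [0.0, 0.0]
polyak_update_sketch(online, target, tau=0.5)
print(target)  # [0.5, 1.0]
polyak_update_sketch(online, target, tau=1.0)
print(target)  # [1.0, 2.0] (tau=1 copies the parameters)
```
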

stable_baselines3.common.utils.safe_mean(arr)

Compute the mean of an array if there is at least one element. For an empty array, return NaN. It is used for logging only.

Parameters

arr (Union[ndarray, list, deque]) – NumPy array or list of values

Return type

ndarray

Returns

The mean of arr, or NaN if arr is empty

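A pure-Python sketch of the same guard, for example applied to a bounded deque of recent episode rewards as one might keep for logging (the deque and the name safe_mean_sketch are illustrative assumptions):

```python
import math
from collections import deque
from typing import Sequence, Union


# Hypothetical helper illustrating safe_mean's documented behavior (not the library source).
def safe_mean_sketch(arr: Union[Sequence[float], deque]) -> float:
    """Mean of `arr`, or NaN when it is empty (so logging never divides by zero)."""
    if len(arr) == 0:
        return math.nan
    return sum(arr) / len(arr)


episode_rewards = deque([10.0, 20.0, 30.0], maxlen=100)
print(safe_mean_sketch(episode_rewards))  # 20.0
print(safe_mean_sketch([]))               # nan
```
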
stable_baselines3.common.utils.set_random_seed(seed, using_cuda=False)

Seed the different random generators (Python's random module, NumPy and PyTorch).

Parameters
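  • seed (int) – the seed value

  • using_cuda (bool) – whether CUDA is used; if True, cuDNN is also set to deterministic mode, which may reduce performance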

Return type

None

stable_baselines3.common.utils.should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes)

Helper used in collect_rollouts() of off-policy algorithms to determine the termination condition.

Parameters
  • train_freq (TrainFreq) – How much experience should be collected before updating the policy.

  • num_collected_steps (int) – The number of already collected steps.

  • num_collected_episodes (int) – The number of already collected episodes.

Return type

bool

Returns

Whether to continue collecting experience by doing rollouts of the current policy.
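
The termination rule can be sketched as follows. TrainFreq is modelled here as a (frequency, unit) pair whose unit counts either steps or episodes; the Unit enum and TrainFreqSketch dataclass are hypothetical stand-ins, not imports from the library:

```python
from dataclasses import dataclass
from enum import Enum


# Hypothetical stand-ins for the library's train frequency types.
class Unit(Enum):
    STEP = "step"
    EPISODE = "episode"


@dataclass
class TrainFreqSketch:
    frequency: int
    unit: Unit


def should_collect_more_steps_sketch(
    train_freq: TrainFreqSketch, num_collected_steps: int, num_collected_episodes: int
) -> bool:
    """Keep collecting until `frequency` steps (or episodes) have been gathered."""
    if train_freq.unit is Unit.STEP:
        return num_collected_steps < train_freq.frequency
    if train_freq.unit is Unit.EPISODE:
        return num_collected_episodes < train_freq.frequency
    raise ValueError(f"Unknown train_freq unit: {train_freq.unit}")


every_4_steps = TrainFreqSketch(4, Unit.STEP)
print(should_collect_more_steps_sketch(every_4_steps, 3, 0))  # True
print(should_collect_more_steps_sketch(every_4_steps, 4, 0))  # False
```
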

stable_baselines3.common.utils.update_learning_rate(optimizer, learning_rate)

Update the learning rate for a given optimizer. Useful when using a learning rate schedule (e.g. a linear schedule).

Parameters
  • optimizer (Optimizer) – PyTorch optimizer

  • learning_rate (float) – new learning rate value

Return type

None
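
PyTorch optimizers store their hyperparameters in optimizer.param_groups, a list of dicts that each hold an "lr" entry. A sketch of the update under that layout; to keep it runnable without PyTorch, a SimpleNamespace with the same param_groups structure stands in for a real optimizer (the stand-in and update_learning_rate_sketch are hypothetical):

```python
from types import SimpleNamespace


# Hypothetical sketch; relies only on the param_groups layout of torch.optim optimizers.
def update_learning_rate_sketch(optimizer, learning_rate: float) -> None:
    """Overwrite the learning rate of every parameter group in place."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate


# Stand-in with the same attribute layout as torch.optim.Optimizer.param_groups.
fake_optimizer = SimpleNamespace(param_groups=[{"lr": 1e-3}, {"lr": 5e-4}])
update_learning_rate_sketch(fake_optimizer, 2.5e-4)
print([group["lr"] for group in fake_optimizer.param_groups])  # [0.00025, 0.00025]
```

In a training loop, the new value would typically come from evaluating a schedule at the current progress_remaining.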

stable_baselines3.common.utils.zip_strict(*iterables)

zip() function but enforces that the iterables are of equal length. Raises ValueError if the iterables are not of equal length. Code inspired by a Stack Overflow answer to question #32954486.

Parameters

*iterables – iterables to zip()

Return type

Iterable
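
One common way to get this behavior is to pad with a unique sentinel via itertools.zip_longest and fail as soon as the sentinel shows up. A sketch of that approach (zip_strict_sketch is a hypothetical name, not the library's source):

```python
from itertools import zip_longest
from typing import Iterable, Iterator, Tuple


# Hypothetical sketch of a length-checking zip (not the library source).
def zip_strict_sketch(*iterables: Iterable) -> Iterator[Tuple]:
    """Like zip(), but raise ValueError if the iterables have different lengths."""
    sentinel = object()  # unique marker that cannot appear in the inputs
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if any(item is sentinel for item in combo):
            raise ValueError("Iterables have different lengths")
        yield combo


print(list(zip_strict_sketch([1, 2], "ab")))  # [(1, 'a'), (2, 'b')]
```

Because the check runs during iteration, pairs before the mismatch are yielded before ValueError is raised. On Python 3.10 and later, the built-in zip(*iterables, strict=True) provides the same guarantee.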