Utils
- stable_baselines3.common.utils.check_for_correct_spaces(env, observation_space, action_space)[source]
Checks that the environment has the same spaces as the provided ones. Used by BaseAlgorithm to check if spaces match after loading a model with a given env. Checked parameters: observation_space and action_space.
- Parameters:
  - env (Union[Env, VecEnv]) – Environment to check for valid spaces
  - observation_space (Space) – Observation space to check against
  - action_space (Space) – Action space to check against
- Return type: None
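The check boils down to an equality comparison of the env's spaces against the saved ones. A minimal sketch (not the library implementation) using a plain dummy object in place of a gym environment:

```python
from types import SimpleNamespace

def check_for_correct_spaces(env, observation_space, action_space):
    # Sketch of the check: compare the env's spaces against the given ones
    if observation_space != env.observation_space:
        raise ValueError(
            f"Observation spaces do not match: {observation_space} != {env.observation_space}"
        )
    if action_space != env.action_space:
        raise ValueError(
            f"Action spaces do not match: {action_space} != {env.action_space}"
        )

# Dummy env standing in for a gym.Env (hypothetical, for illustration)
env = SimpleNamespace(observation_space="Box(4,)", action_space="Discrete(2)")
check_for_correct_spaces(env, "Box(4,)", "Discrete(2)")  # passes silently
```

With a real environment, the spaces would be gymnasium `Space` objects, whose `__eq__` compares shape and dtype.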
- stable_baselines3.common.utils.check_shape_equal(space1, space2)[source]
If the spaces are Box, check that they have the same shape.
If the spaces are Dict, recursively check the subspaces.
- Parameters:
  - space1 (Space) – First space
  - space2 (Space) – Other space
- Return type: None
- stable_baselines3.common.utils.configure_logger(verbose=0, tensorboard_log=None, tb_log_name='', reset_num_timesteps=True)[source]
Configure the logger’s outputs.
- Parameters:
  - verbose (int) – Verbosity level: 0 for no output, 1 for the standard output to be part of the logger outputs
  - tensorboard_log (Optional[str]) – the log location for tensorboard (if None, no logging)
  - tb_log_name (str) – tensorboard log name
  - reset_num_timesteps (bool) – Whether the num_timesteps attribute is reset or not. It allows continuing a previous learning curve (reset_num_timesteps=False) or starting from t=0 (reset_num_timesteps=True, the default).
- Return type: Logger
- Returns:
  The logger object
- stable_baselines3.common.utils.constant_fn(val)[source]
Create a function that returns a constant value. It is useful for learning rate schedules (to avoid code duplication).
- Parameters:
  - val (float) – constant value
- Return type: Callable[[float], float]
- Returns:
  Constant schedule function.
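The idea is a closure that ignores the schedule argument; a minimal sketch (not the library code):

```python
def constant_fn(val):
    # Return a schedule that ignores progress_remaining and always yields val
    def func(_progress_remaining):
        return val
    return func

# Constant learning-rate schedule
lr_schedule = constant_fn(3e-4)
```

Wherever a schedule `Callable[[float], float]` is expected, `lr_schedule` can be passed directly.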
- stable_baselines3.common.utils.explained_variance(y_pred, y_true)[source]
Computes the fraction of variance that y_pred explains about y_true. Returns 1 - Var[y_true - y_pred] / Var[y_true].
- Interpretation:
  - ev = 0 => might as well have predicted zero
  - ev = 1 => perfect prediction
  - ev < 0 => worse than just predicting zero
- Parameters:
  - y_pred (ndarray) – the prediction
  - y_true (ndarray) – the expected value
- Return type: ndarray
- Returns:
  explained variance of y_pred and y_true
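The formula above can be sketched directly in NumPy (a reimplementation for illustration, not the library code; the real function also guards against zero target variance):

```python
import numpy as np

def explained_variance(y_pred, y_true):
    # 1 - Var[y_true - y_pred] / Var[y_true]; NaN when the target variance is zero
    var_y = np.var(y_true)
    return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y)

y_true = np.array([1.0, 2.0, 3.0])
perfect = explained_variance(y_true, y_true)                   # -> 1.0
bad = explained_variance(np.array([3.0, 2.0, 1.0]), y_true)    # negative: worse than zero
```

During training, values drifting toward 1 indicate the value function is tracking the returns.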
- stable_baselines3.common.utils.get_device(device='auto')[source]
Retrieve the PyTorch device. It checks that the requested device is available first. For now, it supports only CPU and CUDA. By default, it tries to use the GPU.
- Parameters:
  - device (Union[device, str]) – One of 'auto', 'cuda', 'cpu'
- Return type: device
- Returns:
  Supported PyTorch device
- stable_baselines3.common.utils.get_latest_run_id(log_path='', log_name='')[source]
Returns the latest run number for the given log name and log path, by finding the greatest number in the directories.
- Parameters:
  - log_path (str) – Path to the log folder containing several runs.
  - log_name (str) – Name of the experiment. Each run is stored in a folder named log_name_1, log_name_2, …
- Return type: int
- Returns:
  latest run number
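The lookup scans the log folder for directories named `<log_name>_<number>` and keeps the largest number; a self-contained sketch (close to, but not verbatim, the library code), demonstrated on a temporary directory:

```python
import glob
import os
import tempfile

def get_latest_run_id(log_path="", log_name=""):
    # Find folders named "<log_name>_<number>" and return the greatest number
    max_run_id = 0
    for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")):
        run_folder = path.split(os.sep)[-1]
        suffix = run_folder.split("_")[-1]
        if log_name == "_".join(run_folder.split("_")[:-1]) and suffix.isdigit():
            max_run_id = max(max_run_id, int(suffix))
    return max_run_id

with tempfile.TemporaryDirectory() as tmp:
    for i in (1, 2, 5):
        os.makedirs(os.path.join(tmp, f"PPO_{i}"))
    latest = get_latest_run_id(tmp, "PPO")  # -> 5
```

The next run would then be stored in `PPO_6`.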
- stable_baselines3.common.utils.get_linear_fn(start, end, end_fraction)[source]
Create a function that interpolates linearly between start and end between progress_remaining = 1 and progress_remaining = end_fraction. This is used in DQN for linearly annealing the exploration fraction (epsilon for the epsilon-greedy strategy).
- Parameters:
  - start – value to start with if progress_remaining = 1
  - end – value to end with if progress_remaining = 0
  - end_fraction – fraction of progress_remaining where end is reached, e.g. 0.1 means end is reached after 10% of the complete training process.
- Return type: Callable[[float], float]
- Returns:
  Linear schedule function.
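The interpolation can be sketched as follows (a reimplementation for illustration; note that `progress_remaining` counts down from 1 to 0 over training):

```python
def get_linear_fn(start, end, end_fraction):
    # Interpolate from start (progress_remaining = 1) down to end,
    # which is reached once a fraction end_fraction of training has elapsed
    def func(progress_remaining):
        if (1 - progress_remaining) > end_fraction:
            return end
        return start + (1 - progress_remaining) * (end - start) / end_fraction
    return func

# DQN-style epsilon schedule: anneal 1.0 -> 0.05 over the first 10% of training
epsilon = get_linear_fn(1.0, 0.05, 0.1)
```

Halfway through the annealing window (5% of training elapsed), epsilon sits midway between 1.0 and 0.05.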
- stable_baselines3.common.utils.get_parameters_by_name(model, included_names)[source]
Extract parameters from the state dict of model if the name contains one of the strings in included_names.
- Parameters:
  - model (Module) – the model where the parameters come from.
  - included_names (Iterable[str]) – substrings of names to include.
- Return type: List[Tensor]
- Returns:
  List of parameter values (PyTorch tensors) that match the queried names.
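The filtering is a substring match over the state-dict keys; a sketch with a dummy object standing in for an `nn.Module` (hypothetical, for illustration):

```python
def get_parameters_by_name(model, included_names):
    # Keep any state-dict entry whose name contains one of the substrings
    return [
        param
        for name, param in model.state_dict().items()
        if any(key in name for key in included_names)
    ]

# Dummy model with a torch-like state_dict() (real values would be tensors)
class DummyModel:
    def state_dict(self):
        return {"policy.weight": [1.0], "target.weight": [2.0], "target.bias": [3.0]}

params = get_parameters_by_name(DummyModel(), ["target"])  # -> [[2.0], [3.0]]
```

This pattern is how off-policy algorithms collect, e.g., all batch-norm statistics of a target network by name.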
- stable_baselines3.common.utils.get_schedule_fn(value_schedule)[source]
Transform (if needed) learning rate and clip range (for PPO) to a callable.
- Parameters:
  - value_schedule (Union[Callable[[float], float], float]) – Constant value or schedule function
- Return type: Callable[[float], float]
- Returns:
  Schedule function (can return a constant value)
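The normalization is simple: pass callables through, wrap plain numbers in a constant schedule. A sketch (not the library code, which also validates the input):

```python
def get_schedule_fn(value_schedule):
    # Wrap a plain float into a constant schedule; pass callables through
    if callable(value_schedule):
        return value_schedule
    return lambda _progress_remaining: float(value_schedule)

clip_range = get_schedule_fn(0.2)                                   # constant
lr = get_schedule_fn(lambda progress_remaining: 1e-3 * progress_remaining)  # decaying
```

Downstream code can then always call the result with `progress_remaining`, regardless of what the user supplied.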
- stable_baselines3.common.utils.get_system_info(print_info=True)[source]
Retrieve system and python env info for the current system.
- Parameters:
  - print_info (bool) – Whether or not to print this info
- Return type: Tuple[Dict[str, str], str]
- Returns:
  Dictionary summing up the version for each relevant package, and a formatted string.
- stable_baselines3.common.utils.is_vectorized_box_observation(observation, observation_space)[source]
For box observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
  - observation (ndarray) – the input observation to validate
  - observation_space (Box) – the observation space
- Return type: bool
- Returns:
  whether the given observation is vectorized or not
- stable_baselines3.common.utils.is_vectorized_dict_observation(observation, observation_space)[source]
For dict observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
  - observation (ndarray) – the input observation to validate
  - observation_space (Dict) – the observation space
- Return type: bool
- Returns:
  whether the given observation is vectorized or not
- stable_baselines3.common.utils.is_vectorized_discrete_observation(observation, observation_space)[source]
For discrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
  - observation (Union[int, ndarray]) – the input observation to validate
  - observation_space (Discrete) – the observation space
- Return type: bool
- Returns:
  whether the given observation is vectorized or not
- stable_baselines3.common.utils.is_vectorized_multibinary_observation(observation, observation_space)[source]
For multibinary observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
  - observation (ndarray) – the input observation to validate
  - observation_space (MultiBinary) – the observation space
- Return type: bool
- Returns:
  whether the given observation is vectorized or not
- stable_baselines3.common.utils.is_vectorized_multidiscrete_observation(observation, observation_space)[source]
For multidiscrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
  - observation (ndarray) – the input observation to validate
  - observation_space (MultiDiscrete) – the observation space
- Return type: bool
- Returns:
  whether the given observation is vectorized or not
- stable_baselines3.common.utils.is_vectorized_observation(observation, observation_space)[source]
For every observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
  - observation (Union[int, ndarray]) – the input observation to validate
  - observation_space (Space) – the observation space
- Return type: bool
- Returns:
  whether the given observation is vectorized or not
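For the Box case, "vectorized" means the observation carries an extra leading batch dimension on top of the space's shape. A sketch of that logic (taking the space's shape tuple directly instead of a gym Box object, for illustration):

```python
import numpy as np

def is_vectorized_box_observation(observation, space_shape):
    # Exact match -> single observation; extra leading dim -> batched/vectorized
    if observation.shape == space_shape:
        return False
    if observation.shape[1:] == space_shape:
        return True
    raise ValueError(
        f"Error: unexpected observation shape {observation.shape} "
        f"for Box space of shape {space_shape}"
    )

single = np.zeros((4,))    # one observation from a Box((4,)) space
batch = np.zeros((8, 4))   # eight stacked observations
```

`is_vectorized_observation` dispatches to the appropriate per-space variant of this check.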
- stable_baselines3.common.utils.obs_as_tensor(obs, device)[source]
Moves the observation to the given device.
- Parameters:
  - obs (Union[ndarray, Dict[str, ndarray]]) – the observation
  - device (device) – PyTorch device
- Return type: Union[Tensor, Dict[str, Tensor]]
- Returns:
  PyTorch tensor of the observation on the desired device.
- stable_baselines3.common.utils.polyak_update(params, target_params, tau)[source]
Perform a Polyak average update on target_params using params: target parameters are slowly updated towards the main parameters. tau, the soft update coefficient, controls the interpolation: tau=1 corresponds to copying the parameters to the target ones, whereas nothing happens when tau=0. The Polyak update is done in place, with no_grad, and therefore does not create intermediate tensors or a computation graph, reducing memory cost and improving performance. We scale the target params by 1-tau (in place), add the new weights scaled by tau, and store the result of the sum in the target params (in place). See https://github.com/DLR-RM/stable-baselines3/issues/93
- Parameters:
  - params (Iterable[Tensor]) – parameters to use to update the target params
  - target_params (Iterable[Tensor]) – parameters to update
  - tau (float) – the soft update coefficient (“Polyak update”, between 0 and 1)
- Return type: None
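The arithmetic is `target ← (1 - tau) * target + tau * param`, applied element-wise and in place. Illustrated here on plain Python lists; the real function operates on torch tensors using in-place ops under `torch.no_grad()`:

```python
def polyak_update(params, target_params, tau):
    # In-place soft update: target <- (1 - tau) * target + tau * param
    for param, target in zip(params, target_params):
        for i in range(len(target)):
            target[i] = (1 - tau) * target[i] + tau * param[i]

weights = [[1.0, 1.0]]   # main network parameters
target = [[0.0, 0.0]]    # target network parameters, updated in place
polyak_update(weights, target, tau=0.5)  # target becomes [[0.5, 0.5]]
```

With a small tau (e.g. 0.005), the target network trails the main network slowly, which stabilizes off-policy training.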
- stable_baselines3.common.utils.safe_mean(arr)[source]
Compute the mean of an array if there is at least one element. For an empty array, return NaN. It is used for logging only.
- Parameters:
  - arr (Union[ndarray, list, deque]) – Numpy array or list of values
- Return type: float
- Returns:
  Mean of the array, or NaN if the array is empty.
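The point is to avoid the "mean of empty slice" warning when, e.g., no episode has finished yet; a minimal sketch (the library version uses np.mean):

```python
import math

def safe_mean(arr):
    # Mean of the values, or NaN for an empty input (NaN is fine for logging)
    return float("nan") if len(arr) == 0 else sum(arr) / len(arr)

mean_reward = safe_mean([1.0, 2.0, 3.0])  # -> 2.0
empty_case = safe_mean([])                # -> nan, not a crash or warning
```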
- stable_baselines3.common.utils.set_random_seed(seed, using_cuda=False)[source]
Seed the different random generators.
- Parameters:
  - seed (int) – the seed
  - using_cuda (bool) – whether to also seed the CUDA random generators
- Return type: None
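Reproducibility requires seeding every generator in play. A sketch covering Python's and NumPy's generators (the real function additionally seeds PyTorch, and configures CUDA determinism when `using_cuda=True`):

```python
import random

import numpy as np

def set_random_seed(seed, using_cuda=False):
    # Seed the Python and NumPy generators; with PyTorch installed,
    # th.manual_seed(seed) (and CUDA flags) would also be called here
    random.seed(seed)
    np.random.seed(seed)

set_random_seed(42)
first = np.random.rand(3)
set_random_seed(42)
second = np.random.rand(3)  # identical to `first`
```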
- stable_baselines3.common.utils.should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes)[source]
Helper used in collect_rollouts() of off-policy algorithms to determine the termination condition.
- Parameters:
  - train_freq (TrainFreq) – How much experience should be collected before updating the policy.
  - num_collected_steps (int) – The number of already collected steps.
  - num_collected_episodes (int) – The number of already collected episodes.
- Return type: bool
- Returns:
  Whether or not to continue collecting experience by doing rollouts of the current policy.
- stable_baselines3.common.utils.update_learning_rate(optimizer, learning_rate)[source]
Update the learning rate for a given optimizer. Useful when doing linear schedules.
- Parameters:
  - optimizer (Optimizer) – PyTorch optimizer
  - learning_rate (float) – New learning rate value
- Return type: None
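PyTorch optimizers expose their hyperparameters through `param_groups`, so updating the learning rate is a matter of overwriting each group's `"lr"` entry. Sketched here with a dummy object standing in for `torch.optim.Optimizer` (hypothetical, for illustration):

```python
def update_learning_rate(optimizer, learning_rate):
    # Overwrite the "lr" entry of every parameter group
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate

# Dummy optimizer mimicking the torch.optim param_groups structure
class DummyOptimizer:
    def __init__(self):
        self.param_groups = [{"lr": 1e-3}, {"lr": 1e-3}]

opt = DummyOptimizer()
update_learning_rate(opt, 5e-4)
```

Called once per training update with `schedule(progress_remaining)`, this implements learning-rate annealing.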
- stable_baselines3.common.utils.zip_strict(*iterables)[source]
zip() function but enforces that the iterables are of equal length. Raises ValueError if the iterables are not of equal length. Code inspired by a Stack Overflow answer to question #32954486.
- Parameters:
  - *iterables (Iterable) – iterables to zip()
- Return type: Iterable
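The trick is to pad with a unique sentinel via itertools.zip_longest: if the sentinel ever appears, the lengths differed. A sketch of that approach (essentially what the helper does):

```python
from itertools import zip_longest

def zip_strict(*iterables):
    # Pad with a unique sentinel; seeing it means the lengths differ
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError("Iterables have different lengths")
        yield combo

pairs = list(zip_strict([1, 2], ["a", "b"]))  # -> [(1, "a"), (2, "b")]
```

On Python 3.10+, the built-in `zip(*iterables, strict=True)` provides the same guarantee.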