Utils
-
stable_baselines3.common.utils.check_for_correct_spaces(env, observation_space, action_space)
Checks that the environment has the same spaces as the provided ones. Used by BaseAlgorithm to check whether the spaces match after loading the model with a given env. Checked parameters: observation_space and action_space.
- Parameters
env (Union[Env, VecEnv]) – Environment to check for valid spaces
observation_space (Space) – Observation space to check against
action_space (Space) – Action space to check against
- Return type
None
-
stable_baselines3.common.utils.configure_logger(verbose=0, tensorboard_log=None, tb_log_name='', reset_num_timesteps=True)
Configure the logger’s outputs.
- Parameters
verbose (int) – the verbosity level: 0 no output, 1 info, 2 debug
tensorboard_log (Optional[str]) – the log location for tensorboard (if None, no logging)
tb_log_name (str) – tensorboard log
- Return type
None
-
stable_baselines3.common.utils.constant_fn(val)
Create a function that returns a constant. It is useful for learning rate schedules (to avoid code duplication).
- Parameters
val (float) –
- Return type
Callable[[float], float]
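The closure pattern behind a constant schedule can be sketched in plain Python. This is an illustrative stand-in rather than the library's source; the name `constant_schedule` is hypothetical.

```python
from typing import Callable


def constant_schedule(val: float) -> Callable[[float], float]:
    """Hypothetical stand-in: build a schedule that ignores its input."""

    def func(_progress_remaining: float) -> float:
        # The argument is accepted only so the signature matches other
        # schedules; the returned value never changes.
        return val

    return func


lr_schedule = constant_schedule(3e-4)
```

Calling `lr_schedule(1.0)` and `lr_schedule(0.0)` yields the same value, so code that expects a schedule can be handed a fixed number without a special case.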
-
stable_baselines3.common.utils.explained_variance(y_pred, y_true)
Computes the fraction of the variance of y_true that y_pred explains. Returns 1 - Var[y_true - y_pred] / Var[y_true]
- interpretation:
ev=0 => might as well have predicted zero
ev=1 => perfect prediction
ev<0 => worse than just predicting zero
- Parameters
y_pred (ndarray) – the prediction
y_true (ndarray) – the expected value
- Return type
ndarray
- Returns
explained variance of y_pred and y_true
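The formula and its three interpretation cases can be checked with a small pure-Python sketch. The function name `explained_variance_sketch` is hypothetical, the standard library's population variance stands in for the array computation, and returning NaN for a constant target is an assumption made here because the ratio is undefined in that case.

```python
import math
from statistics import pvariance
from typing import Sequence


def explained_variance_sketch(y_pred: Sequence[float], y_true: Sequence[float]) -> float:
    """Hypothetical pure-Python version of 1 - Var[y_true - y_pred] / Var[y_true]."""
    var_y = pvariance(y_true)
    if var_y == 0:
        # Assumption: undefined ratio for a constant target, so return NaN.
        return math.nan
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    return 1 - pvariance(residuals) / var_y


y_true = [1.0, 2.0, 3.0, 4.0]
perfect = explained_variance_sketch(y_true, y_true)               # ev = 1
zeros = explained_variance_sketch([0.0] * 4, y_true)              # ev = 0
worse = explained_variance_sketch([4.0, 3.0, 2.0, 1.0], y_true)   # ev < 0
```

Predicting all zeros leaves the residual variance equal to the target variance, which is why that case scores exactly 0.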
-
stable_baselines3.common.utils.get_device(device='auto')
Retrieve the PyTorch device. It first checks that the requested device is available. For now, it supports only cpu and cuda. By default, it tries to use the GPU.
- Parameters
device (Union[device, str]) – One of ‘auto’, ‘cuda’, ‘cpu’
- Return type
device
-
stable_baselines3.common.utils.get_latest_run_id(log_path=None, log_name='')
Returns the latest run number for the given log name and log path, by finding the greatest number in the directories.
- Return type
int
- Returns
latest run number
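A minimal sketch of "finding the greatest number in the directories" follows. It assumes run directories are named `<log_name>_<N>` and that 0 is returned when no run exists; both are assumptions made for illustration, and `latest_run_id_sketch` is a hypothetical name.

```python
import os
import tempfile


def latest_run_id_sketch(log_path: str, log_name: str) -> int:
    """Hypothetical helper: greatest N among directories named '<log_name>_<N>'."""
    max_run_id = 0  # assumption: 0 means "no previous run found"
    for entry in os.listdir(log_path):
        prefix, _, suffix = entry.rpartition("_")
        is_dir = os.path.isdir(os.path.join(log_path, entry))
        if is_dir and prefix == log_name and suffix.isdigit():
            max_run_id = max(max_run_id, int(suffix))
    return max_run_id


with tempfile.TemporaryDirectory() as tmp:
    for name in ("PPO_1", "PPO_2", "PPO_10", "DQN_42", "PPO_notes"):
        os.mkdir(os.path.join(tmp, name))
    latest_ppo = latest_run_id_sketch(tmp, "PPO")
    latest_sac = latest_run_id_sketch(tmp, "SAC")
```

Comparing the suffixes as integers rather than strings matters here: `PPO_10` must rank above `PPO_2`.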
-
stable_baselines3.common.utils.get_linear_fn(start, end, end_fraction)
Create a function that interpolates linearly between start and end as progress_remaining goes from 1 to 1 - end_fraction. This is used in DQN for linearly annealing the exploration rate (epsilon for the epsilon-greedy strategy).
- Parameters
start – value to start with if progress_remaining = 1
end – value to end with if progress_remaining = 0
end_fraction – fraction of progress_remaining where end is reached, e.g. 0.1 means end is reached after 10% of the complete training process
- Return type
Callable[[float], float]
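The annealing behaviour can be sketched as a closure. The sketch follows the `end_fraction` description (with 0.1, `end` is reached after 10% of training) and holds the final value afterwards; `linear_schedule_sketch` is a hypothetical name, not the library function.

```python
from typing import Callable


def linear_schedule_sketch(start: float, end: float, end_fraction: float) -> Callable[[float], float]:
    """Hypothetical sketch: anneal from start to end over the first end_fraction of training."""

    def func(progress_remaining: float) -> float:
        elapsed = 1 - progress_remaining  # fraction of training already done
        if elapsed > end_fraction:
            return end  # annealing finished: hold the final value
        return start + elapsed * (end - start) / end_fraction

    return func


epsilon = linear_schedule_sketch(start=1.0, end=0.05, end_fraction=0.1)
```

With these arguments, `epsilon(1.0)` is the start value, `epsilon(0.95)` lies halfway between start and end, and any call with `progress_remaining` below 0.9 returns the end value.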
-
stable_baselines3.common.utils.get_schedule_fn(value_schedule)
Transform (if needed) the learning rate and clip range (for PPO) into a callable.
- Parameters
value_schedule (Union[Callable[[float], float], float, int]) –
- Return type
Callable[[float], float]
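The "transform if needed" step can be sketched as follows: callables pass through unchanged, and plain numbers are wrapped in a constant function. `to_schedule_sketch` is a hypothetical name, and converting integers to floats is an assumption based on the declared return type.

```python
from typing import Callable, Union

Schedule = Callable[[float], float]


def to_schedule_sketch(value_schedule: Union[Schedule, float, int]) -> Schedule:
    """Hypothetical sketch: pass callables through, wrap numbers in a constant function."""
    if callable(value_schedule):
        return value_schedule
    constant = float(value_schedule)  # assumption: ints such as 1 become 1.0
    return lambda _progress_remaining: constant


clip_range = to_schedule_sketch(0.2)
learning_rate = to_schedule_sketch(lambda progress_remaining: 3e-4 * progress_remaining)
```

After this normalization, downstream code can always call the schedule with the remaining progress, whether the user supplied a number or a function.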
-
stable_baselines3.common.utils.is_vectorized_observation(observation, observation_space)
For every observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters
observation (ndarray) – the input observation to validate
observation_space (Space) – the observation space
- Return type
bool
- Returns
whether the given observation is vectorized or not
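To illustrate the detect-and-validate idea, here is a shape-only sketch for a single simplified case. It assumes a space described by one fixed shape and treats an extra leading axis as the batch of environments; the real function handles several space types, and `is_vectorized_sketch` is a hypothetical name.

```python
from typing import Tuple


def is_vectorized_sketch(obs_shape: Tuple[int, ...], space_shape: Tuple[int, ...]) -> bool:
    """Hypothetical shape-only check for a space with one fixed shape."""
    if obs_shape == space_shape:
        return False  # a single observation
    if obs_shape[1:] == space_shape:
        return True  # assumption: the leading axis is the batch of environments
    raise ValueError(
        f"Unexpected observation shape {obs_shape}; expected {space_shape} "
        f"or a batch of observations of that shape"
    )
```

A shape that matches neither form is rejected rather than guessed at, which is the validation half of the behaviour described above.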
-
stable_baselines3.common.utils.polyak_update(params, target_params, tau)
Perform a Polyak average update on target_params using params: target parameters are slowly updated towards the main parameters. tau, the soft update coefficient, controls the interpolation: tau=1 corresponds to copying the parameters to the target ones, whereas nothing happens when tau=0. The Polyak update is done in place, with no_grad, and therefore does not create intermediate tensors or a computation graph, reducing memory cost and improving performance. We scale the target params by 1-tau (in place), add the new weights scaled by tau, and store the result of the sum in the target params (in place). See https://github.com/DLR-RM/stable-baselines3/issues/93
- Parameters
params (Iterable[Parameter]) – parameters to use to update the target params
target_params (Iterable[Parameter]) – parameters to update
tau (float) – the soft update coefficient (“Polyak update”, between 0 and 1)
- Return type
None
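The update rule, target = (1 - tau) * target + tau * param, can be illustrated without tensors. The sketch below uses plain lists of floats in place of parameter tensors, so it shows only the arithmetic and the in-place mutation, not the no_grad handling; `polyak_update_sketch` is a hypothetical name.

```python
from typing import List


def polyak_update_sketch(params: List[float], target_params: List[float], tau: float) -> None:
    """Hypothetical list-based sketch: target <- (1 - tau) * target + tau * param, in place."""
    if len(params) != len(target_params):
        raise ValueError("params and target_params must have the same length")
    for i, param in enumerate(params):
        target_params[i] *= 1 - tau      # scale the target in place
        target_params[i] += tau * param  # add the scaled new weight


main = [1.0, 2.0]
target = [0.0, 0.0]
polyak_update_sketch(main, target, tau=0.5)  # target is now [0.5, 1.0]
```

Setting `tau=1.0` overwrites the target with the main values, and `tau=0.0` leaves the target untouched, matching the two limiting cases described above.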
-
stable_baselines3.common.utils.safe_mean(arr)
Compute the mean of an array if there is at least one element. For an empty array, return NaN. It is used for logging only.
- Parameters
arr (Union[ndarray, list, deque]) –
- Return type
ndarray
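The empty-input guard can be sketched with the standard library alone. This version returns a plain float rather than an array, and `safe_mean_sketch` is a hypothetical name.

```python
import math
from collections import deque
from statistics import fmean
from typing import Iterable


def safe_mean_sketch(arr: Iterable[float]) -> float:
    """Hypothetical sketch: mean of the values, or NaN when there are none."""
    values = list(arr)
    return fmean(values) if values else math.nan


episode_rewards = deque([1.0, 2.0, 3.0], maxlen=100)
mean_reward = safe_mean_sketch(episode_rewards)
empty_mean = safe_mean_sketch([])
```

Returning NaN instead of raising keeps a logging call from failing before the first episode has finished.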
-
stable_baselines3.common.utils.set_random_seed(seed, using_cuda=False)
Seed the different random generators.
- Parameters
seed (int) –
using_cuda (bool) –
- Return type
None
-
stable_baselines3.common.utils.should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes)
Helper used in collect_rollouts() of off-policy algorithms to determine the termination condition.
- Parameters
train_freq (TrainFreq) – How much experience should be collected before updating the policy.
num_collected_steps (int) – The number of already collected steps.
num_collected_episodes (int) – The number of already collected episodes.
- Return type
bool
- Returns
Whether to continue collecting experience by doing rollouts of the current policy.
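One way to picture the termination condition is sketched below. It assumes a training frequency is a pair of a count and a unit (steps or episodes); `Unit`, `TrainFreqSketch`, and `should_collect_more_sketch` are hypothetical stand-ins introduced for this illustration, not the library's types.

```python
from enum import Enum
from typing import NamedTuple


class Unit(Enum):  # hypothetical stand-in for the unit of a training frequency
    STEP = "step"
    EPISODE = "episode"


class TrainFreqSketch(NamedTuple):  # hypothetical stand-in for TrainFreq
    frequency: int
    unit: Unit


def should_collect_more_sketch(
    train_freq: TrainFreqSketch, num_collected_steps: int, num_collected_episodes: int
) -> bool:
    """Keep collecting until the counter matching the unit reaches the frequency."""
    if train_freq.unit is Unit.STEP:
        return num_collected_steps < train_freq.frequency
    if train_freq.unit is Unit.EPISODE:
        return num_collected_episodes < train_freq.frequency
    raise ValueError(f"Unsupported unit: {train_freq.unit!r}")
```

Under these assumptions, only the counter that matches the configured unit is consulted; the other one is ignored.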
-
stable_baselines3.common.utils.update_learning_rate(optimizer, learning_rate)
Update the learning rate for a given optimizer. Useful when using a linear schedule.
- Parameters
optimizer (Optimizer) –
learning_rate (float) –
- Return type
None
-
stable_baselines3.common.utils.zip_strict(*iterables)
zip() function that enforces that the iterables are of equal length. Raises ValueError if the iterables are not of equal length. Code inspired by the Stack Overflow answer to question #32954486.
- Parameters
*iterables – iterables to zip()
- Return type
Iterable
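A length-checking zip can be sketched with `itertools.zip_longest` and a sentinel fill value. This is an illustrative version under the hypothetical name `zip_strict_sketch`; the error message is made up for the example.

```python
from itertools import zip_longest
from typing import Any, Iterable, Iterator, Tuple


def zip_strict_sketch(*iterables: Iterable[Any]) -> Iterator[Tuple[Any, ...]]:
    """Hypothetical sketch: like zip(), but unequal lengths raise ValueError."""
    sentinel = object()  # unique marker that cannot appear in user data
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if any(item is sentinel for item in combo):
            raise ValueError("Iterables have different lengths")
        yield combo


pairs = list(zip_strict_sketch([1, 2, 3], "abc"))
```

Because this is a generator, the error surfaces only when iteration reaches the point where one input runs out. On Python 3.10 and later, the built-in `zip(..., strict=True)` performs the same check.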