Utils
- stable_baselines3.common.utils.check_for_correct_spaces(env, observation_space, action_space)
Checks that the environment has the same spaces as the provided ones. Used by BaseAlgorithm to check that the spaces match after loading the model with a given env. Checked parameters: observation_space and action_space.
- Parameters:
env (Union[Env, VecEnv]) – Environment to check for valid spaces
observation_space (Space) – Observation space to check against
action_space (Space) – Action space to check against
- Return type:
None
- stable_baselines3.common.utils.check_shape_equal(space1, space2)
If the spaces are Box, check that they have the same shape.
If the spaces are Dict, recursively check the subspaces.
- Parameters:
space1 (Space) – Space
space2 (Space) – Other space
- Return type:
None
- stable_baselines3.common.utils.configure_logger(verbose=0, tensorboard_log=None, tb_log_name='', reset_num_timesteps=True)
Configure the logger’s outputs.
- Parameters:
verbose (int) – Verbosity level: 0 for no output, 1 for the standard output to be part of the logger outputs
tensorboard_log (Optional[str]) – The log location for TensorBoard (if None, no logging)
tb_log_name (str) – Name of the TensorBoard log (run)
reset_num_timesteps (bool) – Whether the num_timesteps attribute is reset or not. It allows continuing a previous learning curve (reset_num_timesteps=False) or starting from t=0 (reset_num_timesteps=True, the default).
- Return type:
Logger
- Returns:
The logger object
- stable_baselines3.common.utils.constant_fn(val)
Create a function that returns a constant. It is useful for learning rate schedules (to avoid code duplication).
- Parameters:
val (float) – Constant value
- Return type:
Callable[[float], float]
- Returns:
Constant schedule function.
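A minimal sketch of the kind of closure constant_fn is described as returning. The name make_constant_schedule is hypothetical and this is not the library's source; it only shows that the returned schedule accepts the remaining progress and ignores it.

```python
from typing import Callable


def make_constant_schedule(val: float) -> Callable[[float], float]:
    """Hypothetical stand-in for constant_fn: always return val."""

    def schedule(_progress_remaining: float) -> float:
        # Schedules receive the remaining progress (1 -> 0),
        # but a constant schedule does not depend on it.
        return val

    return schedule


lr_schedule = make_constant_schedule(3e-4)
print(lr_schedule(1.0), lr_schedule(0.0))  # 0.0003 0.0003
```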
- stable_baselines3.common.utils.explained_variance(y_pred, y_true)
Computes the fraction of variance that y_pred explains about y_true. Returns 1 - Var[y_true - y_pred] / Var[y_true].
- Interpretation:
ev=0 => might as well have predicted zero
ev=1 => perfect prediction
ev<0 => worse than just predicting zero
- Parameters:
y_pred (ndarray) – The prediction
y_true (ndarray) – The expected value
- Return type:
ndarray
- Returns:
Explained variance of y_pred and y_true
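A minimal pure-Python sketch of the formula above, using statistics.pvariance (population variance, which matches NumPy's default np.var with ddof=0). The function name is hypothetical, and returning NaN when Var[y_true] is zero is an assumption of this sketch. The three calls reproduce the interpretation cases listed above.

```python
import math
from statistics import pvariance
from typing import Sequence


def explained_variance_sketch(y_pred: Sequence[float], y_true: Sequence[float]) -> float:
    """Hypothetical version of 1 - Var[y_true - y_pred] / Var[y_true]."""
    var_y = pvariance(y_true)  # population variance (ddof=0)
    if var_y == 0:
        # Assumption: undefined for constant targets, so return NaN.
        return math.nan
    residuals = [t - p for p, t in zip(y_pred, y_true)]
    return 1.0 - pvariance(residuals) / var_y


y_true = [1.0, 2.0, 3.0, 4.0]
print(explained_variance_sketch(y_true, y_true))                # 1.0 (perfect prediction)
print(explained_variance_sketch([0.0] * 4, y_true))             # 0.0 (predicting zero)
print(explained_variance_sketch([4.0, 3.0, 2.0, 1.0], y_true))  # -3.0 (worse than zero)
```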
- stable_baselines3.common.utils.get_device(device='auto')
Retrieve the PyTorch device. It checks that the requested device is available first. For now, it supports only CPU and CUDA. By default, it tries to use the GPU.
- Parameters:
device (Union[device, str]) – One of ‘auto’, ‘cuda’, ‘cpu’
- Return type:
device
- Returns:
Supported PyTorch device
- stable_baselines3.common.utils.get_latest_run_id(log_path='', log_name='')
Returns the latest run number for the given log name and log path, by finding the greatest number in the directories.
- Parameters:
log_path (str) – Path to the log folder containing several runs.
log_name (str) – Name of the experiment. Each run is stored in a folder named log_name_1, log_name_2, …
- Return type:
int
- Returns:
Latest run number
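A hypothetical sketch of the folder scan described above, assuming run folders are named <log_name>_<N> directly inside log_path and that the result is 0 when no run exists yet. latest_run_id_sketch is not the library function. Run numbers are compared as integers, so PPO_10 beats PPO_2, which a plain string comparison would get wrong.

```python
import os
import tempfile


def latest_run_id_sketch(log_path: str = "", log_name: str = "") -> int:
    """Hypothetical re-implementation: largest N among folders named f"{log_name}_{N}"."""
    search_dir = log_path or "."
    if not os.path.isdir(search_dir):
        return 0
    max_run_id = 0
    for entry in os.listdir(search_dir):
        prefix, sep, suffix = entry.rpartition("_")
        # Keep only "<log_name>_<digits>" entries that are directories.
        if sep and prefix == log_name and suffix.isdigit():
            if os.path.isdir(os.path.join(search_dir, entry)):
                max_run_id = max(max_run_id, int(suffix))
    return max_run_id


with tempfile.TemporaryDirectory() as tmp:
    for name in ("PPO_1", "PPO_2", "PPO_10", "DQN_7", "PPO_notes"):
        os.mkdir(os.path.join(tmp, name))
    print(latest_run_id_sketch(tmp, "PPO"))  # 10
    print(latest_run_id_sketch(tmp, "SAC"))  # 0 (no runs yet)
```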
- stable_baselines3.common.utils.get_linear_fn(start, end, end_fraction)
Create a function that interpolates linearly between start and end, from progress_remaining = 1 to progress_remaining = 1 - end_fraction. This is used in DQN for linearly annealing the exploration fraction (epsilon for the epsilon-greedy strategy).
- Parameters:
start (float) – Value to start with when progress_remaining = 1
end (float) – Value to end with; it is reached once progress_remaining = 1 - end_fraction and kept until progress_remaining = 0
end_fraction (float) – Fraction of the complete training process after which end is reached, e.g. 0.1 means end is reached after 10% of the complete training process.
- Return type:
Callable[[float], float]
- Returns:
Linear schedule function.
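A minimal sketch of the interpolation rule described above, assuming (as the end_fraction description states) that the value is held at end once the annealing window has passed. linear_schedule_sketch is a hypothetical name, not the library's implementation.

```python
from typing import Callable


def linear_schedule_sketch(start: float, end: float, end_fraction: float) -> Callable[[float], float]:
    """Hypothetical stand-in for get_linear_fn."""

    def func(progress_remaining: float) -> float:
        progress_done = 1.0 - progress_remaining  # goes from 0 to 1 during training
        if progress_done > end_fraction:
            # Past the annealing window: hold the final value.
            return end
        # Linear interpolation: start at progress_done=0, end at progress_done=end_fraction.
        return start + progress_done * (end - start) / end_fraction

    return func


epsilon = linear_schedule_sketch(start=1.0, end=0.05, end_fraction=0.1)
print(epsilon(1.0))   # 1.0 at the start of training
print(epsilon(0.95))  # about 0.525, halfway through the annealing window
print(epsilon(0.5))   # 0.05 once 10% of training is done
```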
- stable_baselines3.common.utils.get_parameters_by_name(model, included_names)
Extract parameters from the state dict of model if the name contains one of the strings in included_names.
- Parameters:
model (Module) – The model where the parameters come from.
included_names (Iterable[str]) – Substrings of names to include.
- Return type:
List[Tensor]
- Returns:
List of parameter values (PyTorch tensors) that match the queried names.
- stable_baselines3.common.utils.get_schedule_fn(value_schedule)
Transform (if needed) the learning rate and clip range (for PPO) into a callable.
- Parameters:
value_schedule (Union[Callable[[float], float], float]) – Constant value or schedule function
- Return type:
Callable[[float], float]
- Returns:
Schedule function (can return a constant value)
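A hypothetical sketch of the "transform if needed" behaviour: numbers are wrapped into a constant schedule and callables are passed through unchanged. to_schedule_sketch is not the library function, and raising TypeError for anything else is an assumption of this sketch.

```python
from typing import Callable, Union

Schedule = Callable[[float], float]


def to_schedule_sketch(value_schedule: Union[Schedule, float]) -> Schedule:
    """Hypothetical stand-in for get_schedule_fn: wrap constants, pass callables through."""
    if isinstance(value_schedule, (int, float)):
        constant = float(value_schedule)
        return lambda _progress_remaining: constant
    if not callable(value_schedule):
        # Assumption: the sketch rejects anything that is neither a number nor a callable.
        raise TypeError(f"Expected a float or a callable, got {type(value_schedule)!r}")
    return value_schedule


clip_range = to_schedule_sketch(0.2)         # constant -> callable
lr = to_schedule_sketch(lambda p: 3e-4 * p)  # already a schedule
print(clip_range(0.5), lr(0.5))  # 0.2 0.00015
```

Normalising both cases to a callable lets the training loop always call schedule(progress_remaining) without checking the type first.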
- stable_baselines3.common.utils.get_system_info(print_info=True)
Retrieve system and Python environment info for the current system.
- Parameters:
print_info (bool) – Whether to print this info or not
- Return type:
Tuple[Dict[str, str], str]
- Returns:
Dictionary summing up the version of each relevant package, and a formatted string.
- stable_baselines3.common.utils.is_vectorized_box_observation(observation, observation_space)
For the Box observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
observation (ndarray) – The input observation to validate
observation_space (Box) – The observation space
- Return type:
bool
- Returns:
Whether the given observation is vectorized or not
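A hypothetical shape-only sketch of the Box check, working on plain shape tuples instead of NumPy arrays and gymnasium spaces. It assumes a single observation has exactly the space's shape, a batch is laid out as (n_envs, *space_shape), and any other shape is rejected with ValueError. is_vectorized_box_shape is not a library function.

```python
from typing import Tuple


def is_vectorized_box_shape(obs_shape: Tuple[int, ...], space_shape: Tuple[int, ...]) -> bool:
    """Hypothetical check in the spirit of is_vectorized_box_observation."""
    if obs_shape == space_shape:
        return False  # a single, non-batched observation
    if obs_shape[1:] == space_shape:
        return True  # a batch of shape (n_envs, *space_shape)
    raise ValueError(
        f"Unexpected observation shape {obs_shape} for Box space of shape {space_shape}"
    )


print(is_vectorized_box_shape((4,), (4,)))                   # False
print(is_vectorized_box_shape((8, 4), (4,)))                 # True
print(is_vectorized_box_shape((2, 84, 84, 3), (84, 84, 3)))  # True
```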
- stable_baselines3.common.utils.is_vectorized_dict_observation(observation, observation_space)
For the Dict observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
observation (ndarray) – The input observation to validate
observation_space (Dict) – The observation space
- Return type:
bool
- Returns:
Whether the given observation is vectorized or not
- stable_baselines3.common.utils.is_vectorized_discrete_observation(observation, observation_space)
For the Discrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
observation (Union[int, ndarray]) – The input observation to validate
observation_space (Discrete) – The observation space
- Return type:
bool
- Returns:
Whether the given observation is vectorized or not
- stable_baselines3.common.utils.is_vectorized_multibinary_observation(observation, observation_space)
For the MultiBinary observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
observation (ndarray) – The input observation to validate
observation_space (MultiBinary) – The observation space
- Return type:
bool
- Returns:
Whether the given observation is vectorized or not
- stable_baselines3.common.utils.is_vectorized_multidiscrete_observation(observation, observation_space)
For the MultiDiscrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
observation (ndarray) – The input observation to validate
observation_space (MultiDiscrete) – The observation space
- Return type:
bool
- Returns:
Whether the given observation is vectorized or not
- stable_baselines3.common.utils.is_vectorized_observation(observation, observation_space)
For every observation type, detects and validates the shape, then returns whether or not the observation is vectorized.
- Parameters:
observation (Union[int, ndarray]) – The input observation to validate
observation_space (Space) – The observation space
- Return type:
bool
- Returns:
Whether the given observation is vectorized or not
- stable_baselines3.common.utils.obs_as_tensor(obs, device)
Moves the observation to the given device.
- Parameters:
obs (Union[ndarray, Dict[str, ndarray]]) – The observation (or dict of observations) to move
device (device) – PyTorch device
- Return type:
Union[Tensor, Dict[str, Tensor]]
- Returns:
PyTorch tensor of the observation on the desired device.
- stable_baselines3.common.utils.polyak_update(params, target_params, tau)
Perform a Polyak average update on target_params using params: target parameters are slowly updated towards the main parameters. tau, the soft update coefficient, controls the interpolation: tau=1 corresponds to copying the parameters to the target ones, whereas nothing happens when tau=0. The Polyak update is done in place, with no_grad, and therefore does not create intermediate tensors or a computation graph, reducing memory cost and improving performance. We scale the target params by 1-tau (in place), add the new weights scaled by tau, and store the result of the sum in the target params (in place). See https://github.com/DLR-RM/stable-baselines3/issues/93
- Parameters:
params (Iterable[Tensor]) – Parameters to use to update the target params
target_params (Iterable[Tensor]) – Parameters to update
tau (float) – The soft update coefficient (“Polyak update”, between 0 and 1)
- Return type:
None
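The arithmetic described above can be illustrated on plain Python floats. This is a hypothetical analogue, not the library code, which works on PyTorch tensors. Like the description, it mutates the target in place rather than building a new list, and it shows the tau=1 (copy) and tau=0 (no change) extremes.

```python
from typing import List


def polyak_update_sketch(params: List[float], target_params: List[float], tau: float) -> None:
    """Hypothetical analogue of polyak_update on flat lists of floats.

    Applies target <- (1 - tau) * target + tau * param element-wise,
    writing the result back into target_params in place.
    """
    if len(params) != len(target_params):
        raise ValueError("params and target_params must have the same length")
    for i, (p, t) in enumerate(zip(params, target_params)):
        target_params[i] = (1.0 - tau) * t + tau * p


online = [1.0, 2.0]
target = [0.0, 0.0]
polyak_update_sketch(online, target, tau=0.5)
print(target)  # [0.5, 1.0]
polyak_update_sketch(online, target, tau=1.0)
print(target)  # [1.0, 2.0]: tau=1 copies the parameters
```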
- stable_baselines3.common.utils.safe_mean(arr)
Compute the mean of an array if there is at least one element. For an empty array, return NaN. It is used for logging only.
- Parameters:
arr (Union[ndarray, list, deque]) – Numpy array or list of values
- Return type:
ndarray
- Returns:
The mean of arr, or NaN if arr is empty
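A minimal sketch of the empty-input guard, using statistics.fmean (which raises StatisticsError on empty input, the failure this guard avoids). safe_mean_sketch is a hypothetical name, and the deque stands in for a rolling logging buffer.

```python
import math
from collections import deque
from statistics import fmean
from typing import Sequence


def safe_mean_sketch(arr: Sequence[float]) -> float:
    """Hypothetical stand-in for safe_mean: NaN for empty input instead of an error."""
    return math.nan if len(arr) == 0 else fmean(arr)


episode_rewards = deque(maxlen=100)  # e.g. a rolling buffer used for logging
print(safe_mean_sketch(episode_rewards))  # nan (nothing logged yet)
episode_rewards.extend([10.0, 20.0])
print(safe_mean_sketch(episode_rewards))  # 15.0
```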
- stable_baselines3.common.utils.set_random_seed(seed, using_cuda=False)
Seed the different random generators.
- Parameters:
seed (int)
using_cuda (bool)
- Return type:
None
- stable_baselines3.common.utils.should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes)
Helper used in collect_rollouts() of off-policy algorithms to determine the termination condition.
- Parameters:
train_freq (TrainFreq) – How much experience should be collected before updating the policy.
num_collected_steps (int) – The number of already collected steps.
num_collected_episodes (int) – The number of already collected episodes.
- Return type:
bool
- Returns:
Whether to continue collecting experience by doing rollouts of the current policy.
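A hypothetical sketch of the termination condition, assuming train_freq pairs a frequency with a unit that is either steps or episodes. FreqUnit and TrainFreqSketch are local stand-ins invented for this example, not the library's TrainFreq type.

```python
from enum import Enum
from typing import NamedTuple


class FreqUnit(Enum):
    # Hypothetical stand-in for the step/episode unit of TrainFreq.
    STEP = "step"
    EPISODE = "episode"


class TrainFreqSketch(NamedTuple):
    # Hypothetical stand-in for TrainFreq: "collect `frequency` units, then train".
    frequency: int
    unit: FreqUnit


def should_collect_more_steps_sketch(
    train_freq: TrainFreqSketch, num_collected_steps: int, num_collected_episodes: int
) -> bool:
    """Keep collecting until `frequency` steps (or episodes) have been gathered."""
    if train_freq.unit == FreqUnit.STEP:
        return num_collected_steps < train_freq.frequency
    if train_freq.unit == FreqUnit.EPISODE:
        return num_collected_episodes < train_freq.frequency
    raise ValueError(f"Unknown train frequency unit: {train_freq.unit}")


every_4_steps = TrainFreqSketch(4, FreqUnit.STEP)
print(should_collect_more_steps_sketch(every_4_steps, 3, 0))  # True
print(should_collect_more_steps_sketch(every_4_steps, 4, 0))  # False
one_episode = TrainFreqSketch(1, FreqUnit.EPISODE)
print(should_collect_more_steps_sketch(one_episode, 500, 0))  # True (episode not finished)
```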
- stable_baselines3.common.utils.update_learning_rate(optimizer, learning_rate)
Update the learning rate for a given optimizer. Useful when using a linear schedule.
- Parameters:
optimizer (Optimizer) – PyTorch optimizer
learning_rate (float) – New learning rate value
- Return type:
None
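A hypothetical sketch, assuming the usual PyTorch convention that an optimizer exposes param_groups as a list of dicts, each holding an "lr" entry. FakeOptimizer is a stand-in invented so the example runs without PyTorch, and the linear decay driven by progress_remaining illustrates the "linear schedule" use case mentioned above.

```python
from typing import Any, Dict, List


class FakeOptimizer:
    """Hypothetical stand-in exposing `param_groups` the way PyTorch optimizers do."""

    def __init__(self, lr: float, n_groups: int = 2) -> None:
        self.param_groups: List[Dict[str, Any]] = [{"lr": lr} for _ in range(n_groups)]


def update_learning_rate_sketch(optimizer: FakeOptimizer, learning_rate: float) -> None:
    # Overwrite the learning rate of every parameter group in place.
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate


opt = FakeOptimizer(lr=1e-3)
total_timesteps = 1000
for step in (0, 500, 1000):
    progress_remaining = 1.0 - step / total_timesteps
    update_learning_rate_sketch(opt, 1e-3 * progress_remaining)  # linear decay
    print(step, [g["lr"] for g in opt.param_groups])
```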
- stable_baselines3.common.utils.zip_strict(*iterables)
zip() function that enforces that the iterables are of equal length. Raises ValueError if the iterables are not of equal length. Code inspired by a Stack Overflow answer to question #32954486.
- Parameters:
*iterables (Iterable) – Iterables to zip()
- Return type:
Iterable
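A minimal sketch of a strict zip using itertools.zip_longest with a private sentinel object, which is one common way to build it. zip_strict_sketch is a hypothetical name, not the library's source. Because it is a generator, the ValueError is raised only when iteration reaches the mismatch, after earlier tuples have already been yielded. On Python 3.10+ the built-in zip(..., strict=True) offers the same guarantee.

```python
from itertools import zip_longest
from typing import Any, Iterable, Iterator, Tuple


def zip_strict_sketch(*iterables: Iterable[Any]) -> Iterator[Tuple[Any, ...]]:
    """Like zip(), but raise ValueError if the iterables have different lengths."""
    sentinel = object()  # unique marker that cannot appear in user data
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if any(item is sentinel for item in combo):
            raise ValueError("Iterables have different lengths")
        yield combo


print(list(zip_strict_sketch([1, 2], "ab")))  # [(1, 'a'), (2, 'b')]
try:
    list(zip_strict_sketch([1, 2, 3], "ab"))
except ValueError as err:
    print(err)  # Iterables have different lengths
```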