Migrating from Stable-Baselines

This is a guide to migrate from Stable-Baselines (SB2) to Stable-Baselines3 (SB3).

It also lists the main changes.

Overview

Overall, Stable-Baselines3 (SB3) keeps the high-level API of Stable-Baselines (SB2). Most of the changes are internal and aim at more consistency. Because of the backend change from TensorFlow to PyTorch, the internal code is much more readable and easier to debug, at the cost of some speed (dynamic graph vs. static graph, see Issue #90). However, the algorithms were extensively benchmarked on Atari games and continuous control PyBullet envs (see Issue #48 and Issue #49), so you should not expect a drop in agent performance when switching from SB2 to SB3.

How to migrate?

In most cases, replacing from stable_baselines with from stable_baselines3 will be sufficient. Some files were moved to the common folder (see below), which can result in import errors. Some algorithms were removed because of their complexity, in order to improve the maintainability of the project. We recommend reading this guide carefully to understand all the changes that were made. You can also take a look at the rl-zoo3 and compare its imports to those of the SB2 rl-zoo for a concrete example of a successful migration.

Note

If you experience a massive slow-down after switching to PyTorch, you may need to tune the number of threads used, either with torch.set_num_threads(1) or by setting OMP_NUM_THREADS=1 (see issue #122 and issue #90).

Breaking Changes

  • SB3 requires Python 3.6+ (instead of Python 3.5+ for SB2)

  • Dropped MPI support

  • Dropped layer normalized policies (MlpLnLstmPolicy, CnnLnLstmPolicy)

  • LSTM policies (MlpLstmPolicy, CnnLstmPolicy) are not supported for the time being

  • Dropped parameter noise for DDPG and DQN

  • PPO is now closer to the original implementation (no clipping of the value function by default), see the PPO section below

  • Orthogonal initialization is only used by A2C/PPO

  • The features extractor (CNN extractor) is shared between the policy and the q-networks for DDPG/SAC/TD3, and only the policy loss is used to update it (much faster)

  • Tensorboard legacy logging was dropped in favor of a single logger for both the terminal and Tensorboard (see Tensorboard integration)

  • We dropped ACKTR/ACER support because of their complexity compared to simpler alternatives (PPO, SAC, TD3) that perform as well.

  • We dropped GAIL support as we are focusing on model-free RL only. You can, however, take a look at the imitation project, which implements GAIL and other imitation learning algorithms on top of SB3.

  • action_probability is currently not implemented in the base class

  • pretrain() method for behavior cloning was removed (see issue #27)

You can take a look at the issue about SB3 implementation design for more details.

Moved Files

  • bench/monitor.py -> common/monitor.py

  • logger.py -> common/logger.py

  • results_plotter.py -> common/results_plotter.py

  • common/cmd_util.py -> common/env_util.py

Utility functions are no longer exported from the common module; you should import them with their absolute path, e.g.:

from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed

instead of from stable_baselines3.common import make_atari_env
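For larger code bases, these import changes can be partly automated. The snippet below is a minimal sketch, not part of SB3: MOVED_MODULES, migrate_module and migrate_imports are hypothetical names. It only rewrites the module path of import statements, following the moved-files list above. Names that have to come from a different SB3 submodule (see the utility functions above) and renamed algorithm classes (e.g. PPO2 -> PPO) still need to be fixed by hand.

```python
import re

# Hypothetical mapping (not part of SB3) from SB2 module paths to their SB3
# locations, following the "Moved Files" list above.
MOVED_MODULES = {
    "stable_baselines.bench": "stable_baselines3.common.monitor",  # "from stable_baselines.bench import Monitor"
    "stable_baselines.bench.monitor": "stable_baselines3.common.monitor",
    "stable_baselines.logger": "stable_baselines3.common.logger",
    "stable_baselines.results_plotter": "stable_baselines3.common.results_plotter",
    "stable_baselines.common.cmd_util": "stable_baselines3.common.env_util",
}

# Module path right after "from"/"import" at the start of a line. The negative
# lookahead keeps "stable_baselines3" itself from matching.
_IMPORT_RE = re.compile(
    r"^([ \t]*(?:from|import)[ \t]+)(stable_baselines(?!\w)(?:\.\w+)*)", re.MULTILINE
)


def migrate_module(path: str) -> str:
    """Map one SB2 dotted module path to its SB3 counterpart."""
    if path in MOVED_MODULES:
        return MOVED_MODULES[path]
    # Default case: only the top-level package name changes.
    return "stable_baselines3" + path[len("stable_baselines"):]


def migrate_imports(source: str) -> str:
    """Rewrite the module part of every SB2 import statement in `source`."""
    return _IMPORT_RE.sub(lambda m: m.group(1) + migrate_module(m.group(2)), source)
```

For example, migrate_imports("from stable_baselines.bench import Monitor") gives "from stable_baselines3.common.monitor import Monitor", while lines that already import stable_baselines3 are left untouched.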

Changes and renaming in parameters

Base-class (all algorithms)

  • load_parameters -> set_parameters

    • get_parameters returns (and set_parameters accepts) a dictionary mapping object names (e.g. "policy") to their PyTorch state dictionaries, which contain tensors and other objects representing their parameters, instead of the simpler mapping of parameter names to NumPy arrays used in SB2. Parameter values are therefore PyTorch tensors rather than NumPy arrays.

Policies

  • cnn_extractor -> features_extractor, as features_extractor is now used with MlpPolicy too

A2C

  • epsilon -> rms_prop_eps

  • lr_schedule is now part of learning_rate (which can be a callable).

  • alpha and momentum can be modified through the optimizer_kwargs key of policy_kwargs.
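In SB2, lr_schedule took a string such as "linear"; in SB3, the equivalent is to pass a callable as learning_rate. SB3 calls it with the remaining training progress, which goes from 1 at the start of training down to 0 at the end. A minimal sketch of a linear decay (the name linear_schedule is ours, not an SB3 function):

```python
from typing import Callable


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Return a schedule that decays linearly from initial_value to 0.

    The returned function receives progress_remaining, which goes from
    1 (beginning of training) to 0 (end of training).
    """

    def schedule(progress_remaining: float) -> float:
        return progress_remaining * initial_value

    return schedule
```

It would then be used as, e.g., A2C("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(7e-4), rms_prop_eps=1e-5), where rms_prop_eps replaces the SB2 epsilon argument.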

Warning

The PyTorch implementation of RMSprop differs from TensorFlow’s, which leads to different and potentially more unstable results. Use the stable_baselines3.common.sb2_compat.rmsprop_tf_like.RMSpropTFLike optimizer to match the results of TensorFlow’s implementation. This can be done through policy_kwargs: A2C("MlpPolicy", env, policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)))
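One of the differences is where epsilon enters the update: PyTorch adds it outside the square root of the running average of squared gradients, whereas TensorFlow adds it inside. The scalar sketch below uses illustrative helper functions only (not the actual optimizer code, and without momentum or centering) to show why this matters when that running average is tiny:

```python
import math


def rmsprop_step_torch_style(grad: float, square_avg: float, lr: float, eps: float) -> float:
    # torch.optim.RMSprop style: eps is added outside the square root.
    return lr * grad / (math.sqrt(square_avg) + eps)


def rmsprop_step_tf_style(grad: float, square_avg: float, lr: float, eps: float) -> float:
    # TensorFlow (SB2) style: eps is added inside the square root.
    return lr * grad / math.sqrt(square_avg + eps)


# Tiny running average of squared gradients (gradients around 1e-5).
torch_step = rmsprop_step_torch_style(grad=1e-5, square_avg=1e-10, lr=7e-4, eps=1e-5)
tf_step = rmsprop_step_tf_style(grad=1e-5, square_avg=1e-10, lr=7e-4, eps=1e-5)
print(torch_step / tf_step)  # about 158: the PyTorch-style step is much larger
```

For a large running average (e.g. 1.0), the two steps are nearly identical, so the discrepancy mostly shows up for parameters with very small gradients.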

PPO

  • cliprange -> clip_range

  • cliprange_vf -> clip_range_vf

  • nminibatches -> batch_size

Warning

nminibatches gave a different batch size depending on the number of environments: batch_size = (n_steps * n_envs) // nminibatches

  • The clip_range_vf behavior is slightly different: set it to None (default) to deactivate value function clipping (in SB2, you had to pass -1, and None meant using clip_range for the clipping)

  • lam -> gae_lambda

  • noptepochs -> n_epochs
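The renames above, together with the nminibatches formula and the new meaning of None for clip_range_vf, can be collected in a small conversion helper. This is a hypothetical sketch, not an SB3 function: convert_ppo2_kwargs is our name, it assumes the SB2 PPO2 defaults n_steps=128, nminibatches=4 and cliprange=0.2 when those keys are missing, and it does not reconcile other default values that differ between the two libraries.

```python
from typing import Any, Dict

# One-to-one renames from SB2 PPO2 to SB3 PPO (see the list above).
_RENAMED = {"cliprange": "clip_range", "lam": "gae_lambda", "noptepochs": "n_epochs"}


def convert_ppo2_kwargs(sb2_kwargs: Dict[str, Any], n_envs: int) -> Dict[str, Any]:
    """Hypothetical helper: translate SB2 PPO2 kwargs into SB3 PPO kwargs."""
    kwargs = dict(sb2_kwargs)  # copy: do not mutate the caller's dict
    sb3_kwargs: Dict[str, Any] = {}

    for old, new in _RENAMED.items():
        if old in kwargs:
            sb3_kwargs[new] = kwargs.pop(old)

    # nminibatches -> batch_size: SB2 split each rollout of n_steps * n_envs
    # samples into nminibatches minibatches (128 and 4 are SB2 PPO2 defaults).
    n_steps = kwargs.pop("n_steps", 128)
    nminibatches = kwargs.pop("nminibatches", 4)
    sb3_kwargs["n_steps"] = n_steps
    sb3_kwargs["batch_size"] = (n_steps * n_envs) // nminibatches

    # cliprange_vf -> clip_range_vf. In SB2, None (the default) reused cliprange
    # and a negative value disabled clipping; in SB3, None disables clipping.
    clip_vf = kwargs.pop("cliprange_vf", None)
    if clip_vf is None:
        sb3_kwargs["clip_range_vf"] = sb3_kwargs.get("clip_range", 0.2)
    elif callable(clip_vf) or clip_vf >= 0:
        sb3_kwargs["clip_range_vf"] = clip_vf
    else:
        sb3_kwargs["clip_range_vf"] = None

    # Remaining arguments (e.g. gamma, ent_coef) are passed through unchanged.
    sb3_kwargs.update(kwargs)
    return sb3_kwargs
```

A possible usage would be PPO("MlpPolicy", env, **convert_ppo2_kwargs(sb2_kwargs, n_envs=env.num_envs)). Mapping a missing cliprange_vf to clip_range reproduces SB2's default value clipping; drop that branch if you want the SB3 default (no clipping) instead.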

PPO default hyperparameters are the ones tuned for continuous control environments. We recommend taking a look at the RL Zoo for hyperparameters tuned for Atari games.

DQN

Only vanilla DQN is implemented right now, but extensions will follow. Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate, which were taken from the Stable Baselines defaults.

DDPG

DDPG now follows the same interface as SAC/TD3. For state/reward normalization, you should use VecNormalize as for all other algorithms.

SAC/TD3

SAC/TD3 now accept any number of critics, e.g. policy_kwargs=dict(n_critics=3), instead of only two before.

Note

SAC/TD3 default hyperparameters (including network architecture) now match the ones from the original papers. DDPG is using TD3 defaults.

SAC

The SAC implementation matches the latest version of the original implementation: it uses two Q-function networks and two target Q-function networks instead of two Q-function networks and one value function network (SB2 implementation, first version of the original implementation). Despite this change, no change in performance should be expected.
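The change can be summarized by the critic targets. The notation below is ours and follows the standard SAC formulation rather than a line-by-line reading of either code base: r is the reward, d the done flag, γ the discount factor, α the entropy coefficient, bars denote target networks, and N is the number of critics (2 by default, configurable through n_critics as described above).

```latex
\begin{aligned}
\text{SB2 value target: } \quad & V_\psi(s) \leftarrow \min_{i=1,2} Q_{\phi_i}(s, \tilde a) - \alpha \log \pi_\theta(\tilde a \mid s), && \tilde a \sim \pi_\theta(\cdot \mid s) \\
\text{SB2 Q target: } \quad & y = r + \gamma (1 - d)\, V_{\bar\psi}(s') \\
\text{SB3 Q target: } \quad & y = r + \gamma (1 - d) \Big( \min_{i=1,\dots,N} Q_{\bar\phi_i}(s', \tilde a') - \alpha \log \pi_\theta(\tilde a' \mid s') \Big), && \tilde a' \sim \pi_\theta(\cdot \mid s')
\end{aligned}
```

In both cases each critic Q_{φ_i}(s, a) is regressed towards y; SB3 simply folds the entropy-regularized minimum into the Q target instead of learning it with a separate value network.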

Note

The SAC predict() method now uses deterministic=False by default, for consistency with the other algorithms. To match SB2 behavior, you need to explicitly pass deterministic=True.

HER

The HER implementation now also supports online sampling of new goals, done in a vectorized way. The goal selection strategy RANDOM is no longer supported. HER now supports the VecNormalize wrapper, but only when online_sampling=True. For performance reasons, the maximum number of steps per episode must be specified (see HER documentation).

New logger API

  • Methods were renamed in the logger:

    • logkv -> record, writekvs -> write, writeseq -> write_sequence

    • logkvs -> record_dict, dumpkvs -> dump

    • getkvs -> get_log_dict, logkv_mean -> record_mean
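A minimal, hypothetical sketch for rewriting SB2 logger calls with the renames above (LOGGER_RENAMES and migrate_logger_calls are our names, not SB3 functions). It only renames method calls; argument lists may also differ between the two loggers, so each rewritten call should still be checked by hand:

```python
import re

# SB2 logger method names mapped to their SB3 equivalents (list above).
LOGGER_RENAMES = {
    "logkv": "record",
    "logkvs": "record_dict",
    "logkv_mean": "record_mean",
    "writekvs": "write",
    "writeseq": "write_sequence",
    "dumpkvs": "dump",
    "getkvs": "get_log_dict",
}

# Match ".<old_name>(" so that only method calls are rewritten. Requiring "("
# right after the name makes the regex backtrack to "logkvs" or "logkv_mean"
# instead of stopping at the shorter "logkv".
_CALL_RE = re.compile(r"\.(" + "|".join(LOGGER_RENAMES) + r")\(")


def migrate_logger_calls(source: str) -> str:
    """Rename SB2 logger method calls in `source` to their SB3 names."""
    return _CALL_RE.sub(lambda m: "." + LOGGER_RENAMES[m.group(1)] + "(", source)
```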

Internal Changes

Please read the Developer Guide section.

New Features (SB3 vs SB2)

  • Much cleaner and more consistent base code (and no more warnings =D!), with static type checks

  • Independent saving/loading/predict for policies

  • A2C now supports Generalized Advantage Estimation (GAE) and advantage normalization (both are deactivated by default)

  • Generalized State-Dependent Exploration (gSDE) is available for A2C/PPO/SAC. It makes it possible to apply RL directly on real robots (see https://arxiv.org/abs/2005.05719)

  • Proper evaluation (using a separate env) is included in the base class (using EvalCallback). If you pass the environment as a string, you can pass create_eval_env=True to the algorithm constructor.

  • Better saving/loading: optimizers are now included in the saved parameters, and there are two new methods, save_replay_buffer and load_replay_buffer, for the replay buffer when using off-policy algorithms (DQN/DDPG/SAC/TD3)

  • You can pass optimizer_class and optimizer_kwargs to policy_kwargs in order to easily customize optimizers

  • Seeding now works properly to have deterministic results

  • The replay buffer does not grow: everything is allocated at build time (faster)

  • We added a memory-efficient replay buffer variant (pass optimize_memory_usage=True to the constructor), which drastically reduces memory usage, especially when using images

  • You can specify an arbitrary number of critics for SAC/TD3 (e.g. policy_kwargs=dict(n_critics=3))
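The two replay buffer points above (preallocation and the memory-efficient variant) can be illustrated with a small ring buffer. This is a simplified sketch, not SB3's ReplayBuffer class: RingReplayBuffer is a hypothetical name, and only the optimize_memory_usage flag mirrors the SB3 constructor argument. In the memory-efficient mode, the next observation is read from the following slot instead of a separate array, which is why one slot has to be skipped when sampling from a full buffer.

```python
import numpy as np


class RingReplayBuffer:
    """Illustrative fixed-size replay buffer (a sketch, not SB3's ReplayBuffer)."""

    def __init__(self, buffer_size: int, obs_shape: tuple, optimize_memory_usage: bool = False):
        self.buffer_size = buffer_size
        self.optimize_memory_usage = optimize_memory_usage
        # All storage is allocated once, at build time: the buffer never grows.
        self.observations = np.zeros((buffer_size, *obs_shape), dtype=np.float32)
        # Memory-efficient variant: no separate array for next observations.
        self.next_observations = None if optimize_memory_usage else np.zeros_like(self.observations)
        self.actions = np.zeros(buffer_size, dtype=np.int64)
        self.rewards = np.zeros(buffer_size, dtype=np.float32)
        self.dones = np.zeros(buffer_size, dtype=np.float32)
        self.pos = 0
        self.full = False

    def add(self, obs, next_obs, action, reward, done) -> None:
        self.observations[self.pos] = obs
        if self.optimize_memory_usage:
            # next_obs goes into the following slot, where it doubles as the
            # next transition's obs, so each observation is stored only once.
            # After a terminal step it may be overwritten by the reset
            # observation, which is harmless when targets are masked by (1 - done).
            self.observations[(self.pos + 1) % self.buffer_size] = next_obs
        else:
            self.next_observations[self.pos] = next_obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.dones[self.pos] = done
        # Circular write index: once full, the oldest transitions are overwritten.
        self.pos = (self.pos + 1) % self.buffer_size
        self.full = self.full or self.pos == 0

    def sample(self, batch_size: int, rng: np.random.Generator):
        if self.optimize_memory_usage and self.full:
            # Slot `pos` had its obs overwritten by the latest next_obs,
            # so it no longer holds a valid transition: skip it.
            inds = (rng.integers(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            inds = rng.integers(0, self.buffer_size if self.full else self.pos, size=batch_size)
        if self.optimize_memory_usage:
            next_obs = self.observations[(inds + 1) % self.buffer_size]
        else:
            next_obs = self.next_observations[inds]
        return self.observations[inds], self.actions[inds], self.rewards[inds], next_obs, self.dones[inds]
```

Storing each observation once roughly halves the memory used by observations, which is what dominates for image inputs; the price is the extra index bookkeeping in add and sample.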