SAC

Soft Actor-Critic (SAC): Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.

SAC is the successor of Soft Q-Learning (SQL) and incorporates the double Q-learning trick from TD3. A key feature of SAC, and a major difference from most common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.

Available Policies

MlpPolicy

alias of stable_baselines3.sac.policies.SACPolicy

CnnPolicy

Policy class (with both actor and critic) for SAC.

Notes

Note

In our implementation, we use an entropy coefficient (as in OpenAI Spinning Up or Facebook Horizon), which is equivalent to the inverse of the reward scale in the original SAC paper. The main reason is that it avoids overly large errors when updating the Q functions.
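For illustration, a minimal sketch of controlling this trade-off through the ent_coef constructor argument documented below (Pendulum-v0 is just an example environment):

from stable_baselines3 import SAC

# Learn the entropy coefficient automatically, starting from 0.1
model_auto = SAC("MlpPolicy", "Pendulum-v0", ent_coef="auto_0.1")

# Or fix it to a constant value (a higher value favors exploration)
model_fixed = SAC("MlpPolicy", "Pendulum-v0", ent_coef=0.2)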

Note

The default policy for SAC differs a bit from the MlpPolicy of other algorithms: it uses ReLU instead of tanh activation, to match the original paper.
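If needed, this default can be overridden by passing activation_fn through policy_kwargs; a minimal sketch, assuming you want tanh activations instead:

import torch as th

from stable_baselines3 import SAC

# Override the default ReLU activation with tanh via policy_kwargs
model = SAC(
    "MlpPolicy",
    "Pendulum-v0",
    policy_kwargs=dict(activation_fn=th.nn.Tanh),
)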

Can I use?

  • Recurrent policies: ❌

  • Multi processing: ❌

  • Gym spaces:

Space          Action  Observation
Discrete       ❌      ✔️
Box            ✔️      ✔️
MultiDiscrete  ❌      ✔️
MultiBinary    ❌      ✔️

Example

import gym
import numpy as np

from stable_baselines3 import SAC
from stable_baselines3.sac import MlpPolicy

env = gym.make('Pendulum-v0')

model = SAC(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("sac_pendulum")

del model # remove to demonstrate saving and loading

model = SAC.load("sac_pendulum")

obs = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()

Results

PyBullet Environments

Results on the PyBullet benchmark (1M steps) using 3 seeds. The complete learning curves are available in the associated issue #48.

Note

Hyperparameters from the gSDE paper were used (as they are tuned for PyBullet envs).

Gaussian means that unstructured Gaussian noise is used for exploration; gSDE (generalized State-Dependent Exploration) is used otherwise.

Environments  SAC (Gaussian)  SAC (gSDE)    TD3 (Gaussian)
HalfCheetah   2757 +/- 53     2984 +/- 202  2774 +/- 35
Ant           3146 +/- 35     3102 +/- 37   3305 +/- 43
Hopper        2422 +/- 168    2262 +/- 1    2429 +/- 126
Walker2D      2184 +/- 54     2136 +/- 67   2063 +/- 185

How to replicate the results?

Clone the rl-zoo repo:

git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/

Run the benchmark (replace $ENV_ID with one of the envs mentioned above):

python train.py --algo sac --env $ENV_ID --eval-episodes 10 --eval-freq 10000

Plot the results:

python scripts/all_plots.py -a sac -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/sac_results
python scripts/plot_from_file.py -i logs/sac_results.pkl -latex -l SAC

Parameters

class stable_baselines3.sac.SAC(policy, env, learning_rate=0.0003, buffer_size=1000000, learning_starts=100, batch_size=256, tau=0.005, gamma=0.99, train_freq=1, gradient_steps=1, action_noise=None, optimize_memory_usage=False, ent_coef='auto', target_update_interval=1, target_entropy='auto', use_sde=False, sde_sample_freq=-1, use_sde_at_warmup=False, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)[source]

Soft Actor-Critic (SAC): Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. This implementation borrows code from the original implementation (https://github.com/haarnoja/sac), from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo (https://github.com/rail-berkeley/softlearning/) and from Stable Baselines (https://github.com/hill-a/stable-baselines).

Paper: https://arxiv.org/abs/1801.01290

Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html

Note: we use a double Q target and not a value target, as discussed in https://github.com/hill-a/stable-baselines/issues/270

Parameters
  • policy (Union[str, Type[SACPolicy]]) – The policy model to use (MlpPolicy, CnnPolicy, …)

  • env (Union[Env, VecEnv, str]) – The environment to learn from (if registered in Gym, can be str)

  • learning_rate (Union[float, Callable[[float], float]]) – learning rate for the Adam optimizer; the same learning rate will be used for all networks (Q-values, actor and value function). It can be a function of the current progress remaining (from 1 to 0)

  • buffer_size (int) – size of the replay buffer

  • learning_starts (int) – how many steps of the model to collect transitions for before learning starts

  • batch_size (int) – Minibatch size for each gradient update

  • tau (float) – the soft update coefficient (“Polyak update”, between 0 and 1)

  • gamma (float) – the discount factor

  • train_freq (Union[int, Tuple[int, str]]) – Update the model every train_freq steps. Alternatively pass a tuple of frequency and unit like (5, "step") or (2, "episode") (see the sketch after this parameter list).

  • gradient_steps (int) – How many gradient steps to do after each rollout (see train_freq). Setting it to -1 means doing as many gradient steps as steps done in the environment during the rollout.

  • action_noise (Optional[ActionNoise]) – the action noise type (None by default); this can help for hard exploration problems. See common.noise for the different action noise types.

  • optimize_memory_usage (bool) – Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195

  • ent_coef (Union[str, float]) – Entropy regularization coefficient (equivalent to the inverse of the reward scale in the original SAC paper). It controls the exploration/exploitation trade-off. Set it to 'auto' to learn it automatically (and 'auto_0.1' to use 0.1 as the initial value).

  • target_update_interval (int) – update the target network every target_update_interval gradient steps.

  • target_entropy (Union[str, float]) – target entropy when learning ent_coef (ent_coef = 'auto')

  • use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)

  • sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout)

  • use_sde_at_warmup (bool) – Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts)

  • tensorboard_log (Optional[str]) – the log location for tensorboard (if None, no logging)

  • create_eval_env (bool) – Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing a string for the environment)

  • policy_kwargs (Optional[Dict[str, Any]]) – additional arguments to be passed to the policy on creation

  • verbose (int) – the verbosity level: 0 no output, 1 info, 2 debug

  • seed (Optional[int]) – Seed for the pseudo random generators

  • device (Union[device, str]) – Device (cpu, cuda, …) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible.

  • _init_setup_model (bool) – Whether or not to build the network at the creation of the instance
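As a rough sketch (not a tuned configuration), here is how several of these arguments fit together; the hyperparameter values below are placeholders, not recommendations:

from stable_baselines3 import SAC

model = SAC(
    "MlpPolicy",
    "Pendulum-v0",
    learning_rate=3e-4,
    buffer_size=100_000,
    train_freq=(1, "episode"),   # update once per episode instead of every step
    gradient_steps=-1,           # as many gradient steps as env steps collected
    ent_coef="auto",             # learn the entropy coefficient automatically
    use_sde=True,                # gSDE exploration instead of unstructured noise
    verbose=1,
)
model.learn(total_timesteps=5_000)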

collect_rollouts(env, callback, train_freq, replay_buffer, action_noise=None, learning_starts=0, log_interval=None)

Collect experiences and store them into a ReplayBuffer.

Parameters
  • env (VecEnv) – The training environment

  • callback (BaseCallback) – Callback that will be called at each step (and at the beginning and end of the rollout)

  • train_freq (TrainFreq) – How much experience to collect by doing rollouts of current policy. Either TrainFreq(<n>, TrainFrequencyUnit.STEP) or TrainFreq(<n>, TrainFrequencyUnit.EPISODE) with <n> being an integer greater than 0.

  • action_noise (Optional[ActionNoise]) – Action noise that will be used for exploration. Required for deterministic policies (e.g. TD3). This can also be used in addition to the stochastic policy for SAC.

  • learning_starts (int) – Number of steps before learning for the warm-up phase.

  • replay_buffer (ReplayBuffer) – Replay buffer in which the collected experience will be stored

  • log_interval (Optional[int]) – Log data every log_interval episodes

Return type

RolloutReturn


get_env()

Returns the current environment (can be None if not defined).

Return type

Optional[VecEnv]

Returns

The current environment

get_parameters()

Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions).

Return type

Dict[str, Dict]

Returns

Mapping from names of the objects to PyTorch state-dicts.

get_vec_normalize_env()

Return the VecNormalize wrapper of the training env if it exists.

Return type

Optional[VecNormalize]

Returns

The VecNormalize env.

learn(total_timesteps, callback=None, log_interval=4, eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name='SAC', eval_log_path=None, reset_num_timesteps=True)[source]

Return a trained model.

Parameters
  • total_timesteps (int) – The total number of samples (env steps) to train on

  • callback (Union[None, Callable, List[BaseCallback], BaseCallback]) – callback(s) called at every step with state of the algorithm.

  • log_interval (int) – The number of timesteps before logging.

  • tb_log_name (str) – the name of the run for TensorBoard logging

  • eval_env (Union[Env, VecEnv, None]) – Environment that will be used to evaluate the agent

  • eval_freq (int) – Evaluate the agent every eval_freq timesteps (this may vary a little)

  • n_eval_episodes (int) – Number of episodes to evaluate the agent

  • eval_log_path (Optional[str]) – Path to a folder where the evaluations will be saved

  • reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)

Return type

OffPolicyAlgorithm

Returns

the trained model
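A short sketch of periodic evaluation during training using the eval_env and eval_freq arguments above (the values are illustrative only):

import gym

from stable_baselines3 import SAC

env = gym.make("Pendulum-v0")
eval_env = gym.make("Pendulum-v0")

model = SAC("MlpPolicy", env, verbose=1)
# Evaluate on a separate environment every 2000 steps, 5 episodes each time
model.learn(
    total_timesteps=10_000,
    eval_env=eval_env,
    eval_freq=2_000,
    n_eval_episodes=5,
    eval_log_path="./sac_eval_logs/",
)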

classmethod load(path, env=None, device='auto', **kwargs)

Load the model from a zip-file

Parameters
  • path (Union[str, Path, BufferedIOBase]) – path to the file (or a file-like) where to load the agent from

  • env (Union[Env, VecEnv, None]) – the new environment to run the loaded model on (can be None if you only need prediction from a trained model); it has priority over any saved environment

  • device (Union[device, str]) – Device on which the code should run.

  • kwargs – extra arguments to change the model when loading

Return type

BaseAlgorithm

load_replay_buffer(path)

Load a replay buffer from a pickle file.

Parameters

path (Union[str, Path, BufferedIOBase]) – Path to the pickled replay buffer.

Return type

None

predict(observation, state=None, mask=None, deterministic=False)

Get the model’s action(s) from an observation

Parameters
  • observation (ndarray) – the input observation

  • state (Optional[ndarray]) – The last states (can be None, used in recurrent policies)

  • mask (Optional[ndarray]) – The last masks (can be None, used in recurrent policies)

  • deterministic (bool) – Whether or not to return deterministic actions.

Return type

Tuple[ndarray, Optional[ndarray]]

Returns

the model’s action and the next state (used in recurrent policies)

save(path, exclude=None, include=None)

Save all the attributes of the object and the model parameters in a zip-file.

Parameters
  • path (Union[str, Path, BufferedIOBase]) – path to the file where the rl agent should be saved

  • exclude (Optional[Iterable[str]]) – name of parameters that should be excluded in addition to the default ones

  • include (Optional[Iterable[str]]) – name of parameters that might be excluded but should be included anyway

Return type

None

save_replay_buffer(path)

Save the replay buffer as a pickle file.

Parameters

path (Union[str, Path, BufferedIOBase]) – Path to the file where the replay buffer should be saved. If path is a str or pathlib.Path, the path is automatically created if necessary.

Return type

None
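For example, to persist the collected experience together with the model so that training can resume later with a pre-filled buffer (the file names are arbitrary):

import gym

from stable_baselines3 import SAC

env = gym.make("Pendulum-v0")
model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=5_000)

# Save the network weights and the replay buffer separately
model.save("sac_pendulum")
model.save_replay_buffer("sac_pendulum_buffer")

# Later: reload both and continue training without starting from an empty buffer
model = SAC.load("sac_pendulum", env=env)
model.load_replay_buffer("sac_pendulum_buffer")
model.learn(total_timesteps=5_000, reset_num_timesteps=False)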

set_env(env)

Checks the validity of the environment and, if it is coherent, sets it as the current environment. Furthermore, wraps any non-vectorized env into a vectorized one. Checked parameters: observation_space, action_space.

Parameters

env (Union[Env, VecEnv]) – The environment for learning a policy

Return type

None

set_parameters(load_path_or_dict, exact_match=True, device='auto')

Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see get_parameters).

Parameters
  • load_path_or_dict – Location of the saved data (path or file-like, see save), or a nested dictionary containing nn.Module parameters used by the policy. The dictionary maps object names to a state-dictionary returned by torch.nn.Module.state_dict().

  • exact_match (bool) – If True, the given parameters should include parameters for each module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters.

  • device (Union[device, str]) – Device on which the code should run.

Return type

None
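A minimal sketch of copying weights from one agent to another with these two methods (both models must share the same architecture):

from stable_baselines3 import SAC

source = SAC("MlpPolicy", "Pendulum-v0")
target = SAC("MlpPolicy", "Pendulum-v0")

# get_parameters() returns a dict of state-dicts (e.g. for "policy")
params = source.get_parameters()
target.set_parameters(params, exact_match=True)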

set_random_seed(seed=None)

Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)

Parameters

seed (Optional[int]) –

Return type

None

train(gradient_steps, batch_size=64)[source]

Sample the replay buffer and do the updates (gradient descent and update target networks)

Return type

None

SAC Policies

stable_baselines3.sac.MlpPolicy

alias of stable_baselines3.sac.policies.SACPolicy

class stable_baselines3.sac.policies.SACPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, use_sde=False, log_std_init=-3, sde_net_arch=None, use_expln=False, clip_mean=2.0, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=True)[source]

Policy class (with both actor and critic) for SAC.

Parameters
  • observation_space (Space) – Observation space

  • action_space (Space) – Action space

  • lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)

  • net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.

  • activation_fn (Type[Module]) – Activation function

  • use_sde (bool) – Whether to use State Dependent Exploration or not

  • log_std_init (float) – Initial value for the log standard deviation

  • sde_net_arch (Optional[List[int]]) – Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features.

  • use_expln (bool) – Use expln() function instead of exp() when using gSDE to ensure a positive standard deviation (cf paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.

  • clip_mean (float) – Clip the mean output when using gSDE to avoid numerical instability.

  • features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.

  • features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.

  • normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)

  • optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default

  • optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer

  • n_critics (int) – Number of critic networks to create.

  • share_features_extractor (bool) – Whether or not to share the features extractor between the actor and the critic (this saves computation time)

forward(obs, deterministic=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type

Tensor

reset_noise(batch_size=1)[source]

Sample new weights for the exploration matrix, when using gSDE.

Parameters

batch_size (int) –

Return type

None
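In practice, these arguments are usually passed through policy_kwargs rather than by instantiating SACPolicy directly; a hedged sketch (the network sizes are placeholders):

from stable_baselines3 import SAC

# Separate network sizes for the actor (pi) and the critics (qf),
# and use three critics instead of the default two
policy_kwargs = dict(
    net_arch=dict(pi=[256, 256], qf=[400, 300]),
    n_critics=3,
)
model = SAC("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs, verbose=1)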

class stable_baselines3.sac.CnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, use_sde=False, log_std_init=-3, sde_net_arch=None, use_expln=False, clip_mean=2.0, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=True)[source]

Policy class (with both actor and critic) for SAC.

Parameters
  • observation_space (Space) – Observation space

  • action_space (Space) – Action space

  • lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)

  • net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.

  • activation_fn (Type[Module]) – Activation function

  • use_sde (bool) – Whether to use State Dependent Exploration or not

  • log_std_init (float) – Initial value for the log standard deviation

  • sde_net_arch (Optional[List[int]]) – Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features.

  • use_expln (bool) – Use expln() function instead of exp() when using gSDE to ensure a positive standard deviation (cf paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.

  • clip_mean (float) – Clip the mean output when using gSDE to avoid numerical instability.

  • features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.

  • features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.

  • normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)

  • optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default

  • optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer

  • n_critics (int) – Number of critic networks to create.

  • share_features_extractor (bool) – Whether or not to share the features extractor between the actor and the critic (this saves computation time)
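A hedged sketch of using the CnnPolicy with an image-based environment that has a continuous action space. CarRacing-v0 is only an assumed example here (it requires the Box2D dependency), and the small buffer size is simply to limit memory usage with image observations:

import gym

from stable_baselines3 import SAC

# CarRacing-v0 has image observations and a Box action space,
# which is what SAC's CnnPolicy expects
env = gym.make("CarRacing-v0")

model = SAC("CnnPolicy", env, buffer_size=10_000, verbose=1)
model.learn(total_timesteps=10_000)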