TD3¶
Twin Delayed DDPG (TD3), from the paper Addressing Function Approximation Error in Actor-Critic Methods.
TD3 is a direct successor of DDPG and improves on it with three major tricks: clipped double Q-learning, delayed policy updates and target policy smoothing. We recommend reading the OpenAI Spinning Up guide on TD3 to learn more about these.
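The short sketch below (illustrative only, not SB3's internal implementation; the networks and tensors are hypothetical stand-ins) shows how the first and third tricks combine when forming the critic target:

import torch

def td3_target(actor_target, q1_target, q2_target, next_obs, rewards, dones,
               gamma=0.99, target_policy_noise=0.2, target_noise_clip=0.5):
    """Illustrative TD3 target: clipped double Q-learning + target policy smoothing."""
    with torch.no_grad():
        next_action = actor_target(next_obs)
        # Target policy smoothing: add clipped Gaussian noise to the target action
        noise = (torch.randn_like(next_action) * target_policy_noise).clamp(
            -target_noise_clip, target_noise_clip)
        next_action = (next_action + noise).clamp(-1.0, 1.0)  # assumes actions in [-1, 1]
        # Clipped double Q-learning: take the minimum of the two target critics
        q_next = torch.min(q1_target(next_obs, next_action),
                           q2_target(next_obs, next_action))
        # The second trick, delayed policy updates, means the actor (and target
        # networks) are only updated every few critic updates.
        return rewards + gamma * (1.0 - dones) * q_next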
Available Policies¶
Policy | Description
---|---
MlpPolicy | alias of TD3Policy
CnnPolicy | Policy class (with both actor and critic) for TD3.
MultiInputPolicy | Policy class (with both actor and critic) for TD3 to be used with Dict observation spaces.
Notes¶
Original paper: https://arxiv.org/pdf/1802.09477.pdf
OpenAI Spinning Guide for TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
Original Implementation: https://github.com/sfujim/TD3
Note
The default policies for TD3 differ a bit from the other MlpPolicy classes: they use ReLU instead of tanh activation, to match the original paper.
Can I use?¶
Recurrent policies: ❌
Multi processing: ✔️
Gym spaces:
Space | Action | Observation
---|---|---
Discrete | ❌ | ✔️
Box | ✔️ | ✔️
MultiDiscrete | ❌ | ✔️
MultiBinary | ❌ | ✔️
Dict | ❌ | ✔️
Example¶
This example is only meant to demonstrate the use of the library and its functions; the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo repository.
import gym
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make("Pendulum-v1")
# The noise objects for TD3
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)
model.save("td3_pendulum")
env = model.get_env()
del model # remove to demonstrate saving and loading
model = TD3.load("td3_pendulum")
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
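As an optional extra step (not part of the original example), the loaded agent can be evaluated quantitatively with SB3's evaluate_policy helper, e.g. in place of the rendering loop above:

from stable_baselines3.common.evaluation import evaluate_policy

# Evaluate the loaded agent deterministically over a few episodes
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")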
Results¶
PyBullet Environments¶
Results on the PyBullet benchmark (1M steps) using 3 seeds. The complete learning curves are available in the associated issue #48.
Note
Hyperparameters from the gSDE paper were used (as they are tuned for PyBullet envs).
Gaussian means that the unstructured Gaussian noise is used for exploration, gSDE (generalized State-Dependent Exploration) is used otherwise.
Environments | SAC (Gaussian) | SAC (gSDE) | TD3 (Gaussian)
---|---|---|---
HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35
Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43
Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126
Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185
How to replicate the results?¶
Clone the rl-zoo repo:
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
Run the benchmark (replace $ENV_ID with the environments listed above):
python train.py --algo td3 --env $ENV_ID --eval-episodes 10 --eval-freq 10000
Plot the results:
python scripts/all_plots.py -a td3 -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/td3_results
python scripts/plot_from_file.py -i logs/td3_results.pkl -latex -l TD3
Parameters¶
- class stable_baselines3.td3.TD3(policy, env, learning_rate=0.001, buffer_size=1000000, learning_starts=100, batch_size=100, tau=0.005, gamma=0.99, train_freq=(1, 'episode'), gradient_steps=-1, action_noise=None, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, policy_delay=2, target_policy_noise=0.2, target_noise_clip=0.5, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)[source]¶
Twin Delayed DDPG (TD3), from the paper Addressing Function Approximation Error in Actor-Critic Methods.
Original implementation: https://github.com/sfujim/TD3
Paper: https://arxiv.org/abs/1802.09477
Introduction to TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
- Parameters:
  - policy (Union[str, Type[TD3Policy]]) – The policy model to use (MlpPolicy, CnnPolicy, …)
  - env (Union[Env, VecEnv, str]) – The environment to learn from (if registered in Gym, can be str)
  - learning_rate (Union[float, Callable[[float], float]]) – learning rate for the Adam optimizer; the same learning rate will be used for all networks (Q-values, actor and value function). It can be a function of the current progress remaining (from 1 to 0)
  - buffer_size (int) – size of the replay buffer
  - learning_starts (int) – how many steps of the model to collect transitions for before learning starts
  - batch_size (int) – Minibatch size for each gradient update
  - tau (float) – the soft update coefficient (“Polyak update”, between 0 and 1)
  - gamma (float) – the discount factor
  - train_freq (Union[int, Tuple[int, str]]) – Update the model every train_freq steps. Alternatively pass a tuple of frequency and unit like (5, "step") or (2, "episode").
  - gradient_steps (int) – How many gradient steps to do after each rollout (see train_freq). Set to -1 to do as many gradient steps as steps done in the environment during the rollout.
  - action_noise (Optional[ActionNoise]) – the action noise type (None by default); this can help for hard exploration problems. Cf. common.noise for the different action noise types.
  - replay_buffer_class (Optional[Type[ReplayBuffer]]) – Replay buffer class to use (for instance HerReplayBuffer). If None, it will be automatically selected.
  - replay_buffer_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the replay buffer on creation.
  - optimize_memory_usage (bool) – Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
  - policy_delay (int) – Policy and target networks will only be updated once every policy_delay training steps. The Q-values will be updated policy_delay times more often (i.e. every training step).
  - target_policy_noise (float) – Standard deviation of Gaussian noise added to the target policy (smoothing noise)
  - target_noise_clip (float) – Limit for absolute value of target policy smoothing noise.
  - tensorboard_log (Optional[str]) – the log location for TensorBoard (if None, no logging)
  - policy_kwargs (Optional[Dict[str, Any]]) – additional arguments to be passed to the policy on creation
  - verbose (int) – Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages
  - seed (Optional[int]) – Seed for the pseudo random generators
  - device (Union[device, str]) – Device (cpu, cuda, …) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible.
  - _init_setup_model (bool) – Whether or not to build the network at the creation of the instance
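The snippet below is a configuration sketch that exercises several of the constructor arguments listed above; the hyperparameter values are illustrative, not tuned (see the RL Zoo for tuned settings).

import gym
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

env = gym.make("Pendulum-v1")
n_actions = env.action_space.shape[-1]

model = TD3(
    "MlpPolicy",
    env,
    learning_rate=1e-3,
    buffer_size=200_000,
    learning_starts=1_000,
    batch_size=100,
    train_freq=(1, "episode"),  # collect one episode between updates
    gradient_steps=-1,          # as many gradient steps as env steps collected
    policy_delay=2,             # delayed policy (and target network) updates
    target_policy_noise=0.2,    # std of the target policy smoothing noise
    target_noise_clip=0.5,      # clip range of that noise
    action_noise=NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)),
    verbose=1,
)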
- collect_rollouts(env, callback, train_freq, replay_buffer, action_noise=None, learning_starts=0, log_interval=None)¶
Collect experiences and store them into a ReplayBuffer.
- Parameters:
  - env (VecEnv) – The training environment
  - callback (BaseCallback) – Callback that will be called at each step (and at the beginning and end of the rollout)
  - train_freq (TrainFreq) – How much experience to collect by doing rollouts of the current policy. Either TrainFreq(<n>, TrainFrequencyUnit.STEP) or TrainFreq(<n>, TrainFrequencyUnit.EPISODE) with <n> being an integer greater than 0.
  - action_noise (Optional[ActionNoise]) – Action noise that will be used for exploration. Required for deterministic policies (e.g. TD3). This can also be used in addition to the stochastic policy for SAC.
  - learning_starts (int) – Number of steps before learning for the warm-up phase.
  - replay_buffer (ReplayBuffer) – The replay buffer to store the collected transitions in.
  - log_interval (Optional[int]) – Log data every log_interval episodes
- Return type: RolloutReturn
- get_env()¶
Returns the current environment (can be None if not defined).
- Return type: Optional[VecEnv]
- Returns: The current environment
- get_parameters()¶
Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions).
- Return type: Dict[str, Dict]
- Returns: Mapping from names of the objects to PyTorch state-dicts.
- get_vec_normalize_env()¶
Return the VecNormalize wrapper of the training env if it exists.
- Return type: Optional[VecNormalize]
- Returns: The VecNormalize env.
- learn(total_timesteps, callback=None, log_interval=4, tb_log_name='TD3', reset_num_timesteps=True, progress_bar=False)[source]¶
Return a trained model.
- Parameters:
  - total_timesteps (int) – The total number of samples (env steps) to train on
  - callback (Union[None, Callable, List[BaseCallback], BaseCallback]) – callback(s) called at every step with state of the algorithm.
  - log_interval (int) – The number of episodes before logging.
  - tb_log_name (str) – the name of the run for TensorBoard logging
  - reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)
  - progress_bar (bool) – Display a progress bar using tqdm and rich.
- Return type: TypeVar(SelfTD3, bound=TD3)
- Returns: the trained model
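For instance (illustrative; the evaluation env and run name are arbitrary), learn() can be combined with an EvalCallback and a progress bar:

import gym
from stable_baselines3.common.callbacks import EvalCallback

eval_env = gym.make("Pendulum-v1")
eval_callback = EvalCallback(eval_env, eval_freq=5_000, n_eval_episodes=5)
model.learn(
    total_timesteps=50_000,
    callback=eval_callback,
    log_interval=10,
    tb_log_name="td3_run",  # only used if tensorboard_log was set on the model
    progress_bar=True,      # requires tqdm and rich
)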
- classmethod load(path, env=None, device='auto', custom_objects=None, print_system_info=False, force_reset=True, **kwargs)¶
Load the model from a zip-file.
Warning: load re-creates the model from scratch, it does not update it in-place! For an in-place load use set_parameters instead.
- Parameters:
  - path (Union[str, Path, BufferedIOBase]) – path to the file (or a file-like) where to load the agent from
  - env (Union[Env, VecEnv, None]) – the new environment to run the loaded model on (can be None if you only need prediction from a trained model); it has priority over any saved environment
  - device (Union[device, str]) – Device on which the code should run.
  - custom_objects (Optional[Dict[str, Any]]) – Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in keras.models.load_model. Useful when you have an object in the file that cannot be deserialized.
  - print_system_info (bool) – Whether to print system info from the saved model and the current system info (useful to debug loading issues)
  - force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597
  - kwargs – extra arguments to change the model when loading
- Return type: TypeVar(SelfBaseAlgorithm, bound=BaseAlgorithm)
- Returns: new model instance with loaded parameters
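For example (illustrative), the agent saved earlier can be reloaded onto a fresh environment instance:

import gym
from stable_baselines3 import TD3

# Reload the saved agent and attach a new environment instance
model = TD3.load("td3_pendulum", env=gym.make("Pendulum-v1"), print_system_info=True)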
- load_replay_buffer(path, truncate_last_traj=True)¶
Load a replay buffer from a pickle file.
- Parameters:
  - path (Union[str, Path, BufferedIOBase]) – Path to the pickled replay buffer.
  - truncate_last_traj (bool) – When using HerReplayBuffer with online sampling: if set to True, we assume that the last trajectory in the replay buffer was finished (and truncate it). If set to False, we assume that we continue the same trajectory (same episode).
- Return type: None
- predict(observation, state=None, episode_start=None, deterministic=False)¶
Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images).
- Parameters:
  - observation (Union[ndarray, Dict[str, ndarray]]) – the input observation
  - state (Optional[Tuple[ndarray, ...]]) – The last hidden states (can be None, used in recurrent policies)
  - episode_start (Optional[ndarray]) – The last masks (can be None, used in recurrent policies); this corresponds to the beginning of episodes, where the hidden states of the RNN must be reset.
  - deterministic (bool) – Whether or not to return deterministic actions.
- Return type: Tuple[ndarray, Optional[Tuple[ndarray, ...]]]
- Returns: the model’s action and the next hidden state (used in recurrent policies)
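A minimal usage sketch, reusing the env and model from the example above:

# Illustrative: query a single deterministic action from the trained policy
obs = env.reset()
action, _state = model.predict(obs, deterministic=True)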
- save(path, exclude=None, include=None)¶
Save all the attributes of the object and the model parameters in a zip-file.
- Parameters:
  - path (Union[str, Path, BufferedIOBase]) – path to the file where the RL agent should be saved
  - exclude (Optional[Iterable[str]]) – name of parameters that should be excluded in addition to the default ones
  - include (Optional[Iterable[str]]) – name of parameters that might be excluded but should be included anyway
- Return type: None
- save_replay_buffer(path)¶
Save the replay buffer as a pickle file.
- Parameters:
  - path (Union[str, Path, BufferedIOBase]) – Path to the file where the replay buffer should be saved. If path is a str or pathlib.Path, the path is automatically created if necessary.
- Return type: None
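Illustrative usage of the replay buffer saving/loading methods, e.g. to pause and later resume off-policy training (the file name is arbitrary):

# Save the current replay buffer next to the model
model.save_replay_buffer("td3_pendulum_replay_buffer")
# ... later, after re-creating or loading the model ...
model.load_replay_buffer("td3_pendulum_replay_buffer")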
- set_env(env, force_reset=True)¶
Checks the validity of the environment, and if it is coherent, set it as the current environment. Furthermore, wrap any non-vectorized env into a vectorized one. Checked parameters: observation_space, action_space.
- Parameters:
  - env (Union[Env, VecEnv]) – The environment for learning a policy
  - force_reset (bool) – Force call to reset() before training to avoid unexpected behavior. See issue https://github.com/DLR-RM/stable-baselines3/issues/597
- Return type: None
- set_logger(logger)¶
Setter for the logger object.
- Return type: None
Warning
When passing a custom logger object, this will overwrite tensorboard_log and verbose settings passed to the constructor.
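For example (illustrative; the log folder is arbitrary), a custom logger created with stable_baselines3.common.logger.configure can be attached before training:

from stable_baselines3.common.logger import configure

# Log to stdout, CSV and TensorBoard in the given folder
new_logger = configure("./td3_logs/", ["stdout", "csv", "tensorboard"])
model.set_logger(new_logger)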
- set_parameters(load_path_or_dict, exact_match=True, device='auto')¶
Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see get_parameters).
- Parameters:
  - load_path_or_dict – Location of the saved data (path or file-like, see save), or a nested dictionary containing nn.Module parameters used by the policy. The dictionary maps object names to a state-dictionary returned by torch.nn.Module.state_dict().
  - exact_match (bool) – If True, the given parameters should include parameters for each module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters.
  - device (Union[device, str]) – Device on which the code should run.
- Return type: None
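An illustrative round trip between get_parameters and set_parameters, copying weights from one TD3 instance into another one with the same architecture:

from stable_baselines3 import TD3

# Copy parameters from the trained model into a freshly created one, in-place
params = model.get_parameters()
other_model = TD3("MlpPolicy", "Pendulum-v1")
other_model.set_parameters(params, exact_match=True)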
- set_random_seed(seed=None)¶
Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)
- Parameters:
  - seed (Optional[int]) – Seed for the pseudo-random generators
- Return type: None
TD3 Policies¶
- stable_baselines3.td3.MlpPolicy¶
alias of TD3Policy
- class stable_baselines3.td3.policies.TD3Policy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=False)[source]
Policy class (with both actor and critic) for TD3.
- Parameters:
  - observation_space (Space) – Observation space
  - action_space (Space) – Action space
  - lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
  - net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.
  - activation_fn (Type[Module]) – Activation function
  - features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
  - features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.
  - normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
  - optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
  - optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
  - n_critics (int) – Number of critic networks to create.
  - share_features_extractor (bool) – Whether to share or not the features extractor between the actor and the critic (this saves computation time)
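In typical usage you do not instantiate TD3Policy directly; its constructor arguments above are forwarded through policy_kwargs. A minimal sketch (values are illustrative):

import torch as th
from stable_baselines3 import TD3

# Customize the network architecture and activation via policy_kwargs
model = TD3(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(
        net_arch=dict(pi=[400, 300], qf=[400, 300]),  # actor / critic layer sizes
        activation_fn=th.nn.ReLU,
        n_critics=2,
    ),
)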
- forward(observation, deterministic=False)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- set_training_mode(mode)[source]
Put the policy in either training or evaluation mode.
This affects certain modules, such as batch normalisation and dropout.
- Parameters:
  - mode (bool) – if true, set to training mode, else set to evaluation mode
- Return type: None
- class stable_baselines3.td3.CnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=False)[source]¶
Policy class (with both actor and critic) for TD3.
- Parameters:
  - observation_space (Space) – Observation space
  - action_space (Space) – Action space
  - lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
  - net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.
  - activation_fn (Type[Module]) – Activation function
  - features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
  - features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.
  - normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
  - optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
  - optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
  - n_critics (int) – Number of critic networks to create.
  - share_features_extractor (bool) – Whether to share or not the features extractor between the actor and the critic (this saves computation time)
- class stable_baselines3.td3.MultiInputPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.CombinedExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=False)[source]¶
Policy class (with both actor and critic) for TD3 to be used with Dict observation spaces.
- Parameters:
  - observation_space (Dict) – Observation space
  - action_space (Space) – Action space
  - lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
  - net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.
  - activation_fn (Type[Module]) – Activation function
  - features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
  - features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.
  - normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
  - optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
  - optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
  - n_critics (int) – Number of critic networks to create.
  - share_features_extractor (bool) – Whether to share or not the features extractor between the actor and the critic (this saves computation time)