HER
Hindsight Experience Replay (HER)
HER is an algorithm that works with off-policy methods (DQN, SAC, TD3 and DDPG for example). HER uses the fact that even if a desired goal was not achieved, another goal may have been achieved during a rollout. It creates “virtual” transitions by relabeling transitions (changing the desired goal) from past episodes.
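As an illustration of the relabeling idea, here is a minimal sketch (not the library implementation; the relabel helper and the transition dictionary layout are invented for illustration). A stored transition gets its desired goal replaced by a goal that was actually achieved later in the episode, and its reward is recomputed through the environment's compute_reward():

    def relabel(transition, new_goal, env):
        # Substitute the desired goal in the stored observations...
        obs = dict(transition["obs"], desired_goal=new_goal)
        next_obs = dict(transition["next_obs"], desired_goal=new_goal)
        # ...and recompute the reward for the substituted goal
        reward = env.compute_reward(next_obs["achieved_goal"], new_goal, {})
        return dict(obs=obs, action=transition["action"], next_obs=next_obs, reward=reward)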
Warning
Starting from Stable Baselines3 v1.1.0, HER is no longer a separate algorithm but a replay buffer class HerReplayBuffer that must be passed to an off-policy algorithm when using MultiInputPolicy (to have Dict observation support).
Warning
HER requires the environment to follow the legacy gym_robotics.GoalEnv interface. In short, the gym.Env must have:
- a vectorized implementation of compute_reward()
- a dictionary observation space with three keys: observation, achieved_goal and desired_goal
A minimal sketch of such an environment is shown below.
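The sketch below is hypothetical (class name, shapes and reward threshold are made up; reset() and step() are omitted for brevity). It only shows the Dict observation space with the three required keys and a vectorized compute_reward():

    import numpy as np
    import gym
    from gym import spaces

    class MyGoalEnv(gym.Env):
        """Hypothetical goal-conditioned environment with the expected Dict observation space."""

        def __init__(self):
            goal_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
            self.observation_space = spaces.Dict(
                {
                    "observation": spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32),
                    "achieved_goal": goal_space,
                    "desired_goal": goal_space,
                }
            )
            self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

        def compute_reward(self, achieved_goal, desired_goal, info):
            # Must accept batches of goals (vectorized), since HER relabels many transitions at once
            distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
            return -(distance > 0.05).astype(np.float32)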
Warning
Because it needs access to env.compute_reward()
HER
must be loaded with the env. If you just want to use the trained policy
without instantiating the environment, we recommend saving the policy only.
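One possible way to do that is to save and load the policy object itself (a sketch, assuming a trained DQN model as in the example below; the file name is arbitrary and obs stands for a Dict observation you obtain yourself):

    # Save only the policy weights and constructor parameters
    model.policy.save("her_policy")

    # Later: load the policy alone, without the environment or the replay buffer
    # (for DQN; use the matching policy class of the algorithm you trained)
    from stable_baselines3.dqn.policies import MultiInputPolicy

    policy = MultiInputPolicy.load("her_policy")
    action, _ = policy.predict(obs, deterministic=True)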
Note
Compared to other implementations, the future goal sampling strategy is inclusive: the current transition can be used when re-sampling.
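Illustratively (the names and sizes below are hypothetical, not the buffer's internals), "inclusive" means the sampled index range starts at the current transition rather than the next one:

    import numpy as np

    episode_achieved_goals = np.zeros((10, 3))  # hypothetical: 10 transitions, 3-dim goals
    current_index, episode_end_index = 4, 9
    # "future" strategy: sample among transitions from the current one (inclusive) to the episode end
    sampled_index = np.random.randint(current_index, episode_end_index + 1)
    new_goal = episode_achieved_goals[sampled_index]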
Notes
Original paper: https://arxiv.org/abs/1707.01495
OpenAI paper: Plappert et al. (2018)
OpenAI blog post: https://openai.com/blog/ingredients-for-robotics-research/
Can I use?
Please refer to the documentation of the model you use (DQN, QR-DQN, SAC, TQC, TD3, or DDPG) for this section.
Example
This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo repository.
from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.common.envs import BitFlippingEnv
from stable_baselines3.common.vec_env import DummyVecEnv

model_class = DQN  # works also with SAC, DDPG and TD3

N_BITS = 15

env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)

# Available strategies (cf paper): future, final, episode
goal_selection_strategy = "future"  # equivalent to GoalSelectionStrategy.FUTURE

# Initialize the model
model = model_class(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    # Parameters for HER
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy=goal_selection_strategy,
    ),
    verbose=1,
)

# Train the model
model.learn(1000)

model.save("./her_bit_env")
# Because it needs access to `env.compute_reward()`
# HER must be loaded with the env
model = model_class.load("./her_bit_env", env=env)

obs = env.reset()
for _ in range(100):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, _ = env.step(action)
    if done:
        obs = env.reset()
Results
This implementation was tested on the parking env using 3 seeds.
The complete learning curves are available in the associated PR #120.
How to replicate the results?
Clone the rl-zoo repo:
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
Run the benchmark:
python train.py --algo tqc --env parking-v0 --eval-episodes 10 --eval-freq 10000
Plot the results:
python scripts/all_plots.py -a tqc -e parking-v0 -f logs/ --no-million
Parameters
HER Replay Buffer
- class stable_baselines3.her.HerReplayBuffer(buffer_size, observation_space, action_space, env, device='auto', n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True, n_sampled_goal=4, goal_selection_strategy='future', copy_info_dict=False, online_sampling=None)
Hindsight Experience Replay (HER) buffer. Paper: https://arxiv.org/abs/1707.01495
Replay buffer for sampling HER (Hindsight Experience Replay) transitions.
Note
Compared to other implementations, the future goal sampling strategy is inclusive: the current transition can be used when re-sampling.
- Parameters:
  - buffer_size (int) – Max number of elements in the buffer
  - observation_space (Space) – Observation space
  - action_space (Space) – Action space
  - env (VecEnv) – The training environment
  - device (Union[device, str]) – PyTorch device
  - n_envs (int) – Number of parallel environments
  - optimize_memory_usage (bool) – Enable a memory efficient variant. Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
  - handle_timeout_termination (bool) – Handle timeout termination (due to timelimit) separately and treat the task as an infinite-horizon task. See https://github.com/DLR-RM/stable-baselines3/issues/284
  - n_sampled_goal (int) – Number of virtual transitions to create per real transition, by sampling new goals.
  - goal_selection_strategy (Union[GoalSelectionStrategy, str]) – Strategy for sampling goals for replay. One of ['episode', 'final', 'future']
  - copy_info_dict (bool) – Whether to copy the info dictionary and pass it to the compute_reward() method. Please note that the copy may cause a slowdown. False by default.
- add(obs, next_obs, action, reward, done, infos)
Add elements to the buffer.
- Return type: None
- extend(*args, **kwargs)
Add a new batch of transitions to the buffer.
- Return type: None
- reset()
Reset the buffer.
- Return type: None
- sample(batch_size, env=None)
Sample elements from the replay buffer.
- Parameters:
  - batch_size (int) – Number of elements to sample
  - env (Optional[VecNormalize]) – Associated VecEnv to normalize the observations/rewards when sampling
- Return type: DictReplayBufferSamples
- Returns: Samples
- size()
- Return type: int
- Returns: The current size of the buffer
- static swap_and_flatten(arr)
Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) to convert shape from [n_steps, n_envs, …] (where … is the shape of the features) to [n_steps * n_envs, …], which maintains the order.
- Parameters:
  - arr (ndarray)
- Return type: ndarray
- to_torch(array, copy=True)
Convert a numpy array to a PyTorch tensor. Note: it copies the data by default.
- Parameters:
  - array (ndarray)
  - copy (bool) – Whether or not to copy the data (may be useful to avoid changing things by reference). This argument is inoperative if the device is not the CPU.
- Return type: Tensor
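As a usage note, these arguments are normally passed through replay_buffer_kwargs of an off-policy algorithm rather than by instantiating HerReplayBuffer directly. A minimal sketch (SAC is chosen arbitrarily; env stands for any goal-conditioned environment with a Dict observation space):

    from stable_baselines3 import SAC, HerReplayBuffer
    from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy

    model = SAC(
        "MultiInputPolicy",
        env,  # any goal-conditioned env with a Dict observation space
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=dict(
            n_sampled_goal=4,
            goal_selection_strategy=GoalSelectionStrategy.FUTURE,  # or simply "future"
            copy_info_dict=False,
        ),
    )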