HER

Hindsight Experience Replay (HER)

HER is an algorithm that works with off-policy methods (for example DQN, SAC, TD3 and DDPG). HER exploits the fact that even if the desired goal was not achieved during a rollout, other goals may have been. It creates “virtual” transitions by relabeling transitions from past episodes, that is, by replacing their desired goal.
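To make the relabeling idea concrete, here is a minimal pure-Python sketch on a toy episode. It is not the Stable Baselines3 implementation: goals are plain integers, the reward convention (0 on success, -1 otherwise) is an assumption, and the helpers compute_reward and relabel are hypothetical.

```python
import random


def compute_reward(achieved_goal, desired_goal):
    # Hypothetical sparse reward: 0 on success, -1 otherwise.
    return 0.0 if achieved_goal == desired_goal else -1.0


def relabel(episode, t, rng):
    """Create one virtual transition from transition t of an episode.

    Each transition is a dict with keys 'obs', 'action', 'achieved_goal'
    (goal achieved after the action), 'desired_goal' and 'reward'.
    """
    transition = dict(episode[t])  # shallow copy: the stored episode is not modified
    # "future"-style choice: a goal achieved at step t or later in the same episode
    future_t = rng.randint(t, len(episode) - 1)
    new_goal = episode[future_t]["achieved_goal"]
    transition["desired_goal"] = new_goal
    # The reward must be recomputed for the substituted goal
    transition["reward"] = compute_reward(transition["achieved_goal"], new_goal)
    return transition


# Toy episode: the agent wanted to reach 5 but only reached 1, 2 and 3.
episode = [
    {"obs": 0, "action": 1, "achieved_goal": 1, "desired_goal": 5, "reward": -1.0},
    {"obs": 1, "action": 1, "achieved_goal": 2, "desired_goal": 5, "reward": -1.0},
    {"obs": 2, "action": 1, "achieved_goal": 3, "desired_goal": 5, "reward": -1.0},
]
rng = random.Random(0)
virtual = relabel(episode, 2, rng)
print(virtual["desired_goal"], virtual["reward"])  # 3 0.0: the failed step becomes a success
```

The failed episode now yields a rewarded transition, which is exactly the extra learning signal HER is after in sparse-reward tasks.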

Warning

Starting from Stable Baselines3 v1.1.0, HER is no longer a separate algorithm but a replay buffer class HerReplayBuffer that must be passed to an off-policy algorithm when using MultiInputPolicy (to have Dict observation support).

Warning

HER requires the environment to follow the legacy gym_robotics.GoalEnv interface. In short, the gym.Env must have:

- a vectorized implementation of compute_reward()
- a dictionary observation space with three keys: observation, achieved_goal and desired_goal
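"Vectorized" means compute_reward() accepts batches of achieved and desired goals and returns one reward per row, since relabeled transitions are rewarded in batches. The NumPy sketch below assumes a sparse, distance-based reward with a hypothetical success threshold of 0.05 and the (achieved_goal, desired_goal, info) argument order of the GoalEnv interface. It shows only the reward function and an example observation dict, not a full environment.

```python
import numpy as np

DISTANCE_THRESHOLD = 0.05  # hypothetical success radius


def compute_reward(achieved_goal, desired_goal, info):
    """Sparse reward: 0.0 when the goal is reached, -1.0 otherwise.

    Works for a single goal of shape (goal_dim,) or a batch of shape
    (batch_size, goal_dim). `info` is accepted for interface compatibility
    and ignored in this sketch.
    """
    distance = np.linalg.norm(np.asarray(achieved_goal) - np.asarray(desired_goal), axis=-1)
    return np.where(distance > DISTANCE_THRESHOLD, -1.0, 0.0).astype(np.float32)


# A batch of three (achieved, desired) goal pairs in 2-D
achieved = np.array([[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]])
desired = np.array([[0.0, 0.01], [0.0, 0.0], [0.5, 0.5]])
print(compute_reward(achieved, desired, None))  # rewards 0.0, -1.0, 0.0

# Observations are dicts with the three required keys
obs = {
    "observation": np.zeros(4, dtype=np.float32),
    "achieved_goal": achieved[0],
    "desired_goal": desired[0],
}
```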

Warning

Because it needs access to env.compute_reward(), HER must be loaded with the env. If you just want to use the trained policy without instantiating the environment, we recommend saving the policy only.

Note

Compared to other implementations, the future goal sampling strategy is inclusive: the current transition can be used when re-sampling.
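Inclusive and exclusive "future" sampling differ only in the range of step indices a new goal may be drawn from. The hypothetical helper below assumes an episode of length episode_length and a transition at index t; with inclusive=True the goal achieved at step t itself is a valid candidate, which also means the last transition of an episode can always be relabeled.

```python
import random


def future_goal_index(t, episode_length, rng, inclusive=True):
    """Sample the index of the transition whose achieved goal becomes the new goal."""
    low = t if inclusive else t + 1
    if low > episode_length - 1:
        # Exclusive sampling has no candidate for the last transition
        return None
    return rng.randint(low, episode_length - 1)


rng = random.Random(42)
samples_incl = {future_goal_index(3, 5, rng, inclusive=True) for _ in range(200)}
samples_excl = {future_goal_index(3, 5, rng, inclusive=False) for _ in range(200)}
print(sorted(samples_incl), sorted(samples_excl))  # [3, 4] [4]
```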

Notes

Can I use?

Please refer to the documentation of the model used (DQN, QR-DQN, SAC, TQC, TD3, or DDPG) for that section.

Example

This example only demonstrates the use of the library and its functions; the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo repository.

from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.common.envs import BitFlippingEnv

model_class = DQN  # works also with SAC, DDPG and TD3
N_BITS = 15

env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)

# Available strategies (cf paper): future, final, episode
goal_selection_strategy = "future" # equivalent to GoalSelectionStrategy.FUTURE

# Initialize the model
model = model_class(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    # Parameters for HER
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy=goal_selection_strategy,
    ),
    verbose=1,
)

# Train the model
model.learn(1000)

model.save("./her_bit_env")
# Because it needs access to `env.compute_reward()`
# HER must be loaded with the env
model = model_class.load("./her_bit_env", env=env)

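# BitFlippingEnv follows the Gymnasium API: reset() returns (obs, info)
obs, info = env.reset()
for _ in range(100):
    action, _ = model.predict(obs, deterministic=True)
    # step() returns (obs, reward, terminated, truncated, info)
    obs, reward, terminated, truncated, _ = env.step(action)

    if terminated or truncated:
        obs, info = env.reset()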

Results

This implementation was tested on the parking env using 3 seeds.

The complete learning curves are available in the associated PR #120.

How to replicate the results?

Clone the rl-zoo repo:

git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/

Run the benchmark:

python train.py --algo tqc --env parking-v0 --eval-episodes 10 --eval-freq 10000

Plot the results:

python scripts/all_plots.py -a tqc -e parking-v0 -f logs/ --no-million

Parameters

HER Replay Buffer

class stable_baselines3.her.HerReplayBuffer(buffer_size, observation_space, action_space, env, device='auto', n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True, n_sampled_goal=4, goal_selection_strategy='future', copy_info_dict=False)

Hindsight Experience Replay (HER) buffer. Paper: https://arxiv.org/abs/1707.01495

Replay buffer for sampling HER (Hindsight Experience Replay) transitions.

Note

Compared to other implementations, the future goal sampling strategy is inclusive: the current transition can be used when re-sampling.

Parameters:
  • buffer_size (int) – Max number of elements in the buffer

  • observation_space (Space) – Observation space

  • action_space (Space) – Action space

  • env (VecEnv) – The training environment

  • device (Union[device, str]) – PyTorch device

  • n_envs (int) – Number of parallel environments

  • optimize_memory_usage (bool) – Enable a memory-efficient variant. Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702).

  • handle_timeout_termination (bool) – Handle timeout termination (due to a time limit) separately and treat the task as an infinite-horizon task (see https://github.com/DLR-RM/stable-baselines3/issues/284).

  • n_sampled_goal (int) – Number of virtual transitions to create per real transition, by sampling new goals.

  • goal_selection_strategy (Union[GoalSelectionStrategy, str]) – Strategy for sampling goals for replay. One of [‘episode’, ‘final’, ‘future’]

  • copy_info_dict (bool) – Whether to copy the info dictionary and pass it to compute_reward() method. Please note that the copy may cause a slowdown. False by default.
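Two of these parameters are easiest to read as small formulas. If n_sampled_goal virtual transitions are created per real transition, then a fraction n_sampled_goal / (n_sampled_goal + 1) of the transitions the buffer mixes together carries a relabeled goal (0.8 for the default of 4). For handle_timeout_termination, the usual idea is that an episode cut off by a time limit is not a true terminal state, so the next state's value should still be bootstrapped. Both functions below are hypothetical illustrations of these semantics, not the buffer's internals.

```python
def relabel_fraction(n_sampled_goal):
    # n virtual transitions per real one -> n / (n + 1) of all transitions are relabeled
    return n_sampled_goal / (n_sampled_goal + 1)


def bootstrap_mask(done, timeout, handle_timeout_termination=True):
    """Return 1.0 if the next state's value should be bootstrapped, else 0.0."""
    if handle_timeout_termination:
        # A time-limit truncation is not treated as a real terminal state
        done = done and not timeout
    return 0.0 if done else 1.0


print(relabel_fraction(4))                       # 0.8
print(bootstrap_mask(done=True, timeout=True))   # 1.0: keep bootstrapping after a timeout
print(bootstrap_mask(done=True, timeout=False))  # 0.0: true terminal state
```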

add(obs, next_obs, action, reward, done, infos)

Add elements to the buffer.

Return type:

None

extend(*args, **kwargs)

Add a new batch of transitions to the buffer.

Return type:

None

reset()

Reset the buffer.

Return type:

None

sample(batch_size, env=None)

Sample elements from the replay buffer.

Parameters:
  • batch_size (int) – Number of elements to sample

  • env (Optional[VecNormalize]) – Associated VecEnv to normalize the observations/rewards when sampling

Return type:

DictReplayBufferSamples

Returns:

Samples

set_env(env)

Sets the environment.

Parameters:

env (VecEnv) –

Return type:

None

size()

Return type:

int

Returns:

The current size of the buffer

static swap_and_flatten(arr)

Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) to convert the shape from [n_steps, n_envs, …] (where … is the shape of the features) to [n_steps * n_envs, …], which maintains the order.

Parameters:

arr (ndarray) –

Return type:

ndarray
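In NumPy terms, the operation described above amounts to swapping the first two axes and merging them, so that all steps of environment 0 come first, then all steps of environment 1, and so on. The function below is a sketch under that reading; padding 2-D inputs with a trailing feature axis of size 1 is an assumption made so that arrays without feature dimensions (such as rewards) also work.

```python
import numpy as np


def swap_and_flatten(arr):
    """[n_steps, n_envs, *features] -> [n_steps * n_envs, *features]."""
    shape = arr.shape
    if len(shape) < 3:
        # No feature axis (e.g. rewards): treat it as a trailing axis of size 1
        shape = (*shape, 1)
    return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])


# 3 steps from 2 parallel envs; each value encodes 10 * env_index + step
arr = np.array([[10 * env + step for env in range(2)] for step in range(3)])
print(swap_and_flatten(arr).ravel())  # values 0, 1, 2, 10, 11, 12: each env's steps stay contiguous
```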


to_torch(array, copy=True)

Convert a NumPy array to a PyTorch tensor. Note: it copies the data by default.

Parameters:
  • array (ndarray) –

  • copy (bool) – Whether to copy the data or not (may be useful to avoid changing things by reference). This argument is inoperative if the device is not the CPU.

Return type:

Tensor


truncate_last_trajectory()

If called, we assume that the last trajectory in the replay buffer was finished (and truncate it). If not called, we assume that we continue the same trajectory (same episode).

Return type:

None

Goal Selection Strategies

class stable_baselines3.her.GoalSelectionStrategy(value)

The strategies for selecting new goals when creating artificial transitions.
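The three strategies differ only in which transitions of the same episode may supply the replacement goal. The enum and helper below are a hypothetical, self-contained sketch of that distinction: the member names mirror the strings accepted by goal_selection_strategy ('episode', 'final', 'future'), while the integer values and the goal_index helper are assumptions. Future sampling is inclusive, as described in the note at the top of this page.

```python
import random
from enum import Enum


class GoalSelectionStrategy(Enum):
    FUTURE = 0   # a goal achieved at the current step or later in the same episode
    FINAL = 1    # the goal achieved at the end of the episode
    EPISODE = 2  # a goal achieved anywhere in the episode


# Map the string form ("future", "final", "episode") to enum members
KEY_TO_STRATEGY = {s.name.lower(): s for s in GoalSelectionStrategy}


def goal_index(strategy, t, episode_length, rng):
    """Index of the transition whose achieved goal replaces the desired goal."""
    if isinstance(strategy, str):
        strategy = KEY_TO_STRATEGY[strategy]
    if strategy is GoalSelectionStrategy.FINAL:
        return episode_length - 1
    if strategy is GoalSelectionStrategy.FUTURE:
        return rng.randint(t, episode_length - 1)
    return rng.randint(0, episode_length - 1)  # EPISODE


rng = random.Random(0)
print(goal_index("final", 2, 10, rng))  # 9
```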