Source code for stable_baselines3.her.her_replay_buffer

from collections import deque
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch as th
from gym import spaces

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy


class HerReplayBuffer(ReplayBuffer):
    """
    Replay buffer for sampling HER (Hindsight Experience Replay) transitions.

    Transitions are stored episode by episode. In the online sampling case, the
    relabeled (virtual) transitions are not saved in the replay buffer, they are
    only created at sampling time.

    :param env: The training environment
    :param buffer_size: The size of the buffer measured in transitions.
    :param max_episode_length: The length of an episode. (time horizon)
    :param goal_selection_strategy: Strategy for sampling goals for replay.
        One of ['episode', 'final', 'future']
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param her_ratio: The ratio between HER transitions and regular transitions in percent
        (between 0 and 1, for online sampling)
        The default value ``her_ratio=0.8`` corresponds to 4 virtual transitions
        for one real transition (4 / (4 + 1) = 0.8)
    """

    def __init__(
        self,
        env: ObsDictWrapper,
        buffer_size: int,
        max_episode_length: int,
        goal_selection_strategy: GoalSelectionStrategy,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        her_ratio: float = 0.8,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs)

        self.env = env
        self.buffer_size = buffer_size
        self.max_episode_length = max_episode_length

        # buffer with episodes
        # number of episodes which can be stored until buffer size is reached
        self.max_episode_stored = self.buffer_size // self.max_episode_length
        # index of the next transition to write inside the current episode
        self.current_idx = 0

        # input dimensions for buffer initialization
        input_shape = {
            "observation": (self.env.num_envs, self.env.obs_dim),
            "achieved_goal": (self.env.num_envs, self.env.goal_dim),
            "desired_goal": (self.env.num_envs, self.env.goal_dim),
            "action": (self.action_dim,),
            "reward": (1,),
            "next_obs": (self.env.num_envs, self.env.obs_dim),
            "next_achieved_goal": (self.env.num_envs, self.env.goal_dim),
            "next_desired_goal": (self.env.num_envs, self.env.goal_dim),
            "done": (1,),
        }
        # one array per key, indexed as [episode, timestep, ...]
        self.buffer = {
            key: np.zeros((self.max_episode_stored, self.max_episode_length, *dim), dtype=np.float32)
            for key, dim in input_shape.items()
        }
        # Store info dicts as they can be used to compute the reward (e.g. continuity cost)
        self.info_buffer = [deque(maxlen=self.max_episode_length) for _ in range(self.max_episode_stored)]
        # episode length storage, needed for episodes which have fewer steps than the maximum length
        self.episode_lengths = np.zeros(self.max_episode_stored, dtype=np.int64)

        self.goal_selection_strategy = goal_selection_strategy
        # percentage of her indices
        self.her_ratio = her_ratio

    def __getstate__(self) -> Dict[str, Any]:
        """
        Gets state for pickling.

        Excludes self.env, as in general Env's may not be pickleable.

        :return: A shallow copy of the instance ``__dict__`` without the ``"env"`` entry.
        """
        state = self.__dict__.copy()
        # these attributes are not pickleable
        del state["env"]
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restores pickled state.

        User must call ``set_env()`` after unpickling before using.

        :param state: The state produced by ``__getstate__()`` (must not contain ``"env"``).
        """
        self.__dict__.update(state)
        assert "env" not in state
        self.env = None

    def set_env(self, env: ObsDictWrapper) -> None:
        """
        Sets the environment.

        :param env: The environment to attach to the buffer.
        :raises ValueError: If an environment is already set.
        """
        if self.env is not None:
            raise ValueError("Trying to set env of already initialized environment.")

        self.env = env

    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
        """
        Abstract method from base class.

        Not supported by this buffer: use ``sample()`` or ``sample_offline()`` instead.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError()

    def sample(
        self,
        batch_size: int,
        env: Optional[VecNormalize],
    ) -> Union[ReplayBufferSamples, Tuple[np.ndarray, ...]]:
        """
        Sample function for online sampling of HER transition,
        this replaces the "regular" replay buffer ``sample()``
        method in the ``train()`` function.

        :param batch_size: Number of element to sample
        :param env: Associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: Samples.
        """
        return self._sample_transitions(batch_size, maybe_vec_env=env, online_sampling=True)

    def sample_offline(
        self,
        n_sampled_goal: Optional[int] = None,
    ) -> Union[ReplayBufferSamples, Tuple[np.ndarray, ...]]:
        """
        Sample function for offline sampling of HER transition,
        in that case, only one episode is used and transitions
        are added to the regular replay buffer.

        :param n_sampled_goal: Number of sampled goals for replay
        :return: at most(n_sampled_goal * episode_length) HER transitions.
        """
        # env=None as we should store unnormalized transitions, they will be normalized at sampling time
        return self._sample_transitions(
            batch_size=None, maybe_vec_env=None, online_sampling=False, n_sampled_goal=n_sampled_goal
        )

    def sample_goals(
        self,
        episode_indices: np.ndarray,
        her_indices: np.ndarray,
        transitions_indices: np.ndarray,
    ) -> np.ndarray:
        """
        Sample goals based on goal_selection_strategy.
        This is a vectorized (fast) version.

        :param episode_indices: Episode indices to use.
        :param her_indices: HER indices.
        :param transitions_indices: Transition indices to use.
        :return: Return sampled goals.
        :raises ValueError: If the goal selection strategy is not supported.
        """
        her_episode_indices = episode_indices[her_indices]

        if self.goal_selection_strategy == GoalSelectionStrategy.FINAL:
            # replay with final state of current episode
            transitions_indices = self.episode_lengths[her_episode_indices] - 1

        elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
            # replay with random state which comes from the same episode and was observed after current transition
            transitions_indices = np.random.randint(
                transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices]
            )

        elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE:
            # replay with random state which comes from the same episode as current transition
            transitions_indices = np.random.randint(self.episode_lengths[her_episode_indices])

        else:
            raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!")

        return self.buffer["achieved_goal"][her_episode_indices, transitions_indices]

    def _sample_transitions(
        self,
        batch_size: Optional[int],
        maybe_vec_env: Optional[VecNormalize],
        online_sampling: bool,
        n_sampled_goal: Optional[int] = None,
    ) -> Union[ReplayBufferSamples, Tuple[np.ndarray, ...]]:
        """
        Sample transitions and relabel a subset of them with new goals.

        :param batch_size: Number of element to sample (only used for online sampling)
        :param maybe_vec_env: associated gym VecEnv to normalize the observations/rewards
            Only valid when using online sampling
        :param online_sampling: Using online_sampling for HER or not.
        :param n_sampled_goal: Number of sampled goals for replay. (offline sampling)
        :return: Samples. ``ReplayBufferSamples`` for online sampling, otherwise a tuple
            ``(observations, next_observations, actions, rewards)`` of arrays
            (four empty arrays when no virtual transition can be created).
        """
        # Select which episodes to use
        if online_sampling:
            assert batch_size is not None, "No batch_size specified for online sampling of HER transitions"
            # Do not sample the episode with index `self.pos` as the episode is invalid
            if self.full:
                episode_indices = (
                    np.random.randint(1, self.n_episodes_stored, batch_size) + self.pos
                ) % self.n_episodes_stored
            else:
                episode_indices = np.random.randint(0, self.n_episodes_stored, batch_size)
            # A subset of the transitions will be relabeled using HER algorithm
            her_indices = np.arange(batch_size)[: int(self.her_ratio * batch_size)]
        else:
            assert maybe_vec_env is None, "Transitions must be stored unnormalized in the replay buffer"
            assert n_sampled_goal is not None, "No n_sampled_goal specified for offline sampling of HER transitions"
            # Offline sampling: there is only one episode stored
            episode_length = self.episode_lengths[0]
            # we sample n_sampled_goal per timestep in the episode (only one is stored).
            episode_indices = np.tile(0, (episode_length * n_sampled_goal))
            # we only sample virtual transitions
            # as real transitions are already stored in the replay buffer
            her_indices = np.arange(len(episode_indices))

        # fancy indexing returns a copy, so the in-place decrement below
        # does not modify `self.episode_lengths`
        ep_lengths = self.episode_lengths[episode_indices]

        # Special case when using the "future" goal sampling strategy
        # we cannot sample all transitions, we have to remove the last timestep
        if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
            # restrict the sampling domain when ep_lengths > 1
            # otherwise filter out the indices
            her_indices = her_indices[ep_lengths[her_indices] > 1]
            ep_lengths[her_indices] -= 1

        if online_sampling:
            # Select which transitions to use
            transitions_indices = np.random.randint(ep_lengths)
        else:
            if her_indices.size == 0:
                # Episode of one timestep, not enough for using the "future" strategy
                # no virtual transitions are created in that case
                return np.zeros(0), np.zeros(0), np.zeros(0), np.zeros(0)
            else:
                # Repeat every transition index n_sampled_goals times
                # to sample n_sampled_goal per timestep in the episode (only one is stored).
                # Now with the corrected episode length when using "future" strategy
                transitions_indices = np.tile(np.arange(ep_lengths[0]), n_sampled_goal)
                episode_indices = episode_indices[transitions_indices]
                her_indices = np.arange(len(episode_indices))

        # get selected transitions
        transitions = {key: values[episode_indices, transitions_indices].copy() for key, values in self.buffer.items()}

        # sample new desired goals and relabel the transitions
        new_goals = self.sample_goals(episode_indices, her_indices, transitions_indices)
        transitions["desired_goal"][her_indices] = new_goals

        # Convert info buffer to numpy array
        transitions["info"] = np.array(
            [
                self.info_buffer[episode_idx][transition_idx]
                for episode_idx, transition_idx in zip(episode_indices, transitions_indices)
            ]
        )

        # Edge case: nothing to relabel (e.g. small `her_ratio * batch_size`, or only
        # one-step episodes with the "future" strategy), the stored rewards are kept
        # and the environment is not queried with an empty batch
        if her_indices.size > 0:
            # Vectorized computation of the new reward
            transitions["reward"][her_indices, 0] = self.env.env_method(
                "compute_reward",
                # the new state depends on the previous state and action
                # s_{t+1} = f(s_t, a_t)
                # so the next_achieved_goal depends also on the previous state and action
                # because we are in a GoalEnv:
                # r_t = reward(s_t, a_t) = reward(next_achieved_goal, desired_goal)
                # therefore we have to use "next_achieved_goal" and not "achieved_goal"
                transitions["next_achieved_goal"][her_indices, 0],
                # here we use the new desired goal
                transitions["desired_goal"][her_indices, 0],
                transitions["info"][her_indices, 0],
            )

        # concatenate observation with (desired) goal
        observations = ObsDictWrapper.convert_dict(self._normalize_obs(transitions, maybe_vec_env))
        # HACK to make normalize obs work with the next observation
        transitions["observation"] = transitions["next_obs"]
        next_observations = ObsDictWrapper.convert_dict(self._normalize_obs(transitions, maybe_vec_env))

        if online_sampling:
            data = (
                observations[:, 0],
                transitions["action"],
                next_observations[:, 0],
                transitions["done"],
                self._normalize_reward(transitions["reward"], maybe_vec_env),
            )

            return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
        else:
            return observations, next_observations, transitions["action"], transitions["reward"]

    def add(
        self,
        obs: Dict[str, np.ndarray],
        next_obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[dict],
    ) -> None:
        """
        Store one transition at the current position of the current episode.

        The episode is only finalized when ``store_episode()`` is called.

        :param obs: Observation dict with keys "observation", "achieved_goal" and "desired_goal"
        :param next_obs: Next observation dict, same keys as ``obs``
        :param action: Action taken
        :param reward: Reward received
        :param done: Episode termination signal
        :param infos: Info dicts returned by the environment for this step
        """
        if self.current_idx == 0 and self.full:
            # Clear info buffer: the episode slot is being overwritten
            self.info_buffer[self.pos] = deque(maxlen=self.max_episode_length)

        self.buffer["observation"][self.pos][self.current_idx] = obs["observation"]
        self.buffer["achieved_goal"][self.pos][self.current_idx] = obs["achieved_goal"]
        self.buffer["desired_goal"][self.pos][self.current_idx] = obs["desired_goal"]
        self.buffer["action"][self.pos][self.current_idx] = action
        self.buffer["done"][self.pos][self.current_idx] = done
        self.buffer["reward"][self.pos][self.current_idx] = reward
        self.buffer["next_obs"][self.pos][self.current_idx] = next_obs["observation"]
        self.buffer["next_achieved_goal"][self.pos][self.current_idx] = next_obs["achieved_goal"]
        self.buffer["next_desired_goal"][self.pos][self.current_idx] = next_obs["desired_goal"]

        self.info_buffer[self.pos].append(infos)

        # update current pointer
        self.current_idx += 1

    def store_episode(self) -> None:
        """
        Increment episode counter
        and reset transition pointer.
        """
        # add episode length to length storage
        self.episode_lengths[self.pos] = self.current_idx

        # update current episode pointer
        # Note: in the OpenAI implementation
        # when the buffer is full, the episode replaced
        # is randomly chosen
        self.pos += 1
        if self.pos == self.max_episode_stored:
            self.full = True
            self.pos = 0
        # reset transition pointer
        self.current_idx = 0

    @property
    def n_episodes_stored(self) -> int:
        """
        :return: The number of episode slots in use
            (``max_episode_stored`` once the buffer is full, ``pos`` otherwise).
        """
        if self.full:
            return self.max_episode_stored
        return self.pos

    def size(self) -> int:
        """
        :return: The current number of transitions in the buffer.
        """
        return int(np.sum(self.episode_lengths))

    def reset(self) -> None:
        """
        Reset the buffer.

        The stored info dicts are dropped too: ``add()`` only clears an episode slot
        once the buffer is full, so keeping them would leave stale infos in front of
        the ones appended after the reset, misaligned with the transition indices.
        """
        self.pos = 0
        self.current_idx = 0
        self.full = False
        self.episode_lengths = np.zeros(self.max_episode_stored, dtype=np.int64)
        self.info_buffer = [deque(maxlen=self.max_episode_length) for _ in range(self.max_episode_stored)]