Source code for stable_baselines3.her.her

import io
import pathlib
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

import numpy as np
import torch as th

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_zip_file, recursive_setattr
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, TrainFreq
from stable_baselines3.common.utils import check_for_correct_spaces, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer


def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> int:
    """
    Get time limit from environment.

    :param env: Environment from which we want to get the time limit.
    :param current_max_episode_length: Current value for max_episode_length.
    :return: max episode length
    """
    # try to get the attribute from environment
    if current_max_episode_length is None:
        try:
            current_max_episode_length = env.get_attr("spec")[0].max_episode_steps
            # Raise the error because the attribute is present but is None
            if current_max_episode_length is None:
                raise AttributeError
        # if not available check if a valid value was passed as an argument
        except AttributeError:
            raise ValueError(
                "The max episode length could not be inferred.\n"
                "You must specify a `max_episode_steps` when registering the environment,\n"
                "use a `gym.wrappers.TimeLimit` wrapper "
                "or pass `max_episode_length` to the model constructor"
            )
    return current_max_episode_length
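
# Illustrative sketch (not part of the original module): two ways to make the time limit
# above resolvable, assuming a hypothetical custom environment ``MyGoalEnv``.
# Either register it with ``max_episode_steps`` so that ``env.spec.max_episode_steps`` is set:
#
#     from gym.envs.registration import register
#     register(id="MyGoalEnv-v0", entry_point="my_package.envs:MyGoalEnv", max_episode_steps=100)
#
# or pass ``max_episode_length=100`` directly to the ``HER`` constructor below, which skips
# the inference in ``get_time_limit`` entirely.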


# TODO: rewrite HER class as soon as dict obs are supported
class HER(BaseAlgorithm):
    """
    Hindsight Experience Replay (HER)
    Paper: https://arxiv.org/abs/1707.01495

    .. warning::

        For performance reasons, the maximum number of steps per episode must be specified.
        In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment
        or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None).
        Otherwise, you can directly pass ``max_episode_length`` to the model constructor.

    For additional offline algorithm specific arguments please have a look at the corresponding documentation.

    :param policy: The policy model to use.
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param model_class: Off policy model which will be used with hindsight experience replay. (SAC, TD3, DDPG, DQN)
    :param n_sampled_goal: Number of sampled goals for replay. (offline sampling)
    :param goal_selection_strategy: Strategy for sampling goals for replay.
        One of ['episode', 'final', 'future', 'random']
    :param online_sampling: Sample HER transitions online.
    :param learning_rate: learning rate for the optimizer,
        it can be a function of the current progress remaining (from 1 to 0)
    :param max_episode_length: The maximum length of an episode. If not specified, it will be automatically inferred
        if the environment uses a ``gym.wrappers.TimeLimit`` wrapper.
    """

    def __init__(
        self,
        policy: Union[str, Type[BasePolicy]],
        env: Union[GymEnv, str],
        model_class: Type[OffPolicyAlgorithm],
        n_sampled_goal: int = 4,
        goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future",
        online_sampling: bool = False,
        max_episode_length: Optional[int] = None,
        *args,
        **kwargs,
    ):

        # we will use the policy and learning rate from the model
        super(HER, self).__init__(policy=BasePolicy, env=env, policy_base=BasePolicy, learning_rate=0.0)
        del self.policy, self.learning_rate

        if self.get_vec_normalize_env() is not None:
            assert online_sampling, "You must pass `online_sampling=True` if you want to use `VecNormalize` with `HER`"

        _init_setup_model = kwargs.get("_init_setup_model", True)
        if "_init_setup_model" in kwargs:
            del kwargs["_init_setup_model"]

        # model initialization
        self.model_class = model_class
        self.model = model_class(
            policy=policy,
            env=self.env,
            _init_setup_model=False,  # pytype: disable=wrong-keyword-args
            *args,
            **kwargs,  # pytype: disable=wrong-keyword-args
        )

        # Make HER use self.model.action_noise
        del self.action_noise
        self.verbose = self.model.verbose
        self.tensorboard_log = self.model.tensorboard_log

        # convert goal_selection_strategy into GoalSelectionStrategy if string
        if isinstance(goal_selection_strategy, str):
            self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()]
        else:
            self.goal_selection_strategy = goal_selection_strategy

        # check if goal_selection_strategy is valid
        assert isinstance(
            self.goal_selection_strategy, GoalSelectionStrategy
        ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}"

        self.n_sampled_goal = n_sampled_goal
        # if we sample her transitions online use custom replay buffer
        self.online_sampling = online_sampling
        # compute ratio between HER replays and regular replays in percent for online HER sampling
        self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1))
        # maximum steps in episode
        self.max_episode_length = get_time_limit(self.env, max_episode_length)
        # storage for transitions of current episode for offline sampling
        # for online sampling, it replaces the "classic" replay buffer completely
        her_buffer_size = self.buffer_size if online_sampling else self.max_episode_length

        assert self.env is not None, "Because HER needs access to `env.compute_reward()`, you must provide the env."

        self._episode_storage = HerReplayBuffer(
            self.env,
            her_buffer_size,
            self.max_episode_length,
            self.goal_selection_strategy,
            self.env.observation_space,
            self.env.action_space,
            self.device,
            self.n_envs,
            self.her_ratio,  # pytype: disable=wrong-arg-types
        )

        # counter for steps in episode
        self.episode_steps = 0
        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        self.model._setup_model()
        # assign episode storage to replay buffer when using online HER sampling
        if self.online_sampling:
            self.model.replay_buffer = self._episode_storage
    def predict(
        self,
        observation: np.ndarray,
        state: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:

        return self.model.predict(observation, state, mask, deterministic)
    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "HER",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> BaseAlgorithm:

        total_timesteps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
        )
        self.model.start_time = self.start_time
        self.model.ep_info_buffer = self.ep_info_buffer
        self.model.ep_success_buffer = self.ep_success_buffer
        self.model.num_timesteps = self.num_timesteps
        self.model._episode_num = self._episode_num
        self.model._last_obs = self._last_obs
        self.model._total_timesteps = self._total_timesteps

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:
            rollout = self.collect_rollouts(
                self.env,
                train_freq=self.train_freq,
                action_noise=self.action_noise,
                callback=callback,
                learning_starts=self.learning_starts,
                log_interval=log_interval,
            )

            if rollout.continue_training is False:
                break

            if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts and self.replay_buffer.size() > 0:
                # If no `gradient_steps` is specified,
                # do as many gradient steps as steps performed during the rollout
                gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
                self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

        callback.on_training_end()

        return self
    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        train_freq: TrainFreq,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """
        Collect experiences and store them into a ReplayBuffer.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param log_interval: Log data every ``log_interval`` episodes
        :return:
        """

        episode_rewards, total_timesteps = [], []
        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert env.num_envs == 1, "OffPolicyAlgorithm only supports a single environment"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."

        if self.model.use_sde:
            self.actor.reset_noise()

        callback.on_rollout_start()
        continue_training = True

        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            done = False
            episode_reward, episode_timesteps = 0.0, 0

            while not done:
                # concatenate observation and (desired) goal
                observation = self._last_obs
                self._last_obs = ObsDictWrapper.convert_dict(observation)

                if (
                    self.model.use_sde
                    and self.model.sde_sample_freq > 0
                    and num_collected_steps % self.model.sde_sample_freq == 0
                ):
                    # Sample a new noise matrix
                    self.actor.reset_noise()

                # Select action randomly or according to policy
                self.model._last_obs = self._last_obs
                action, buffer_action = self._sample_action(learning_starts, action_noise)

                # Perform action
                new_obs, reward, done, infos = env.step(action)

                self.num_timesteps += 1
                self.model.num_timesteps = self.num_timesteps
                episode_timesteps += 1
                num_collected_steps += 1

                # Only stop training if return value is False, not when it is None.
                if callback.on_step() is False:
                    return RolloutReturn(0.0, num_collected_steps, num_collected_episodes, continue_training=False)

                episode_reward += reward

                # Retrieve reward and episode length if using Monitor wrapper
                self._update_info_buffer(infos, done)
                self.model.ep_info_buffer = self.ep_info_buffer
                self.model.ep_success_buffer = self.ep_success_buffer

                # == Store transition in the replay buffer and/or in the episode storage ==

                if self._vec_normalize_env is not None:
                    # Store only the unnormalized version
                    new_obs_ = self._vec_normalize_env.get_original_obs()
                    reward_ = self._vec_normalize_env.get_original_reward()
                else:
                    # Avoid changing the original ones
                    self._last_original_obs, new_obs_, reward_ = observation, new_obs, reward
                    self.model._last_original_obs = self._last_original_obs

                # As the VecEnv resets automatically, new_obs is already the
                # first observation of the next episode
                if done and infos[0].get("terminal_observation") is not None:
                    next_obs = infos[0]["terminal_observation"]
                    # VecNormalize normalizes the terminal observation
                    if self._vec_normalize_env is not None:
                        next_obs = self._vec_normalize_env.unnormalize_obs(next_obs)
                else:
                    next_obs = new_obs_

                if self.online_sampling:
                    self.replay_buffer.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos)
                else:
                    # concatenate observation with (desired) goal
                    flattened_obs = ObsDictWrapper.convert_dict(self._last_original_obs)
                    flattened_next_obs = ObsDictWrapper.convert_dict(next_obs)
                    # add to replay buffer
                    self.replay_buffer.add(flattened_obs, flattened_next_obs, buffer_action, reward_, done)
                    # add current transition to episode storage
                    self._episode_storage.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos)

                self._last_obs = new_obs
                self.model._last_obs = self._last_obs

                # Save the unnormalized new observation
                if self._vec_normalize_env is not None:
                    self._last_original_obs = new_obs_
                    self.model._last_original_obs = self._last_original_obs

                self.model._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

                # For DQN, check if the target network should be updated
                # and update the exploration schedule
                # For SAC/TD3, the update is done at the same time as the gradient update
                # see https://github.com/hill-a/stable-baselines/issues/900
                self.model._on_step()

                self.episode_steps += 1

                if not should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
                    break

            if done or self.episode_steps >= self.max_episode_length:
                if self.online_sampling:
                    self.replay_buffer.store_episode()
                else:
                    self._episode_storage.store_episode()
                    # sample virtual transitions and store them in replay buffer
                    self._sample_her_transitions()
                    # clear storage for current episode
                    self._episode_storage.reset()

                num_collected_episodes += 1
                self._episode_num += 1
                self.model._episode_num = self._episode_num
                episode_rewards.append(episode_reward)
                total_timesteps.append(episode_timesteps)

                if action_noise is not None:
                    action_noise.reset()

                # Log training infos
                if log_interval is not None and self._episode_num % log_interval == 0:
                    self._dump_logs()

                self.episode_steps = 0

        mean_reward = np.mean(episode_rewards) if num_collected_episodes > 0 else 0.0

        callback.on_rollout_end()

        return RolloutReturn(mean_reward, num_collected_steps, num_collected_episodes, continue_training)
    def _sample_her_transitions(self) -> None:
        """
        Sample additional goals and store new transitions in replay buffer
        when using offline sampling.
        """

        # Sample goals and get new observations
        # maybe_vec_env=None as we should store unnormalized transitions,
        # they will be normalized at sampling time
        observations, next_observations, actions, rewards = self._episode_storage.sample_offline(
            n_sampled_goal=self.n_sampled_goal
        )

        # store data in replay buffer
        dones = np.zeros((len(observations)), dtype=bool)
        self.replay_buffer.extend(observations, next_observations, actions, rewards, dones)

    def __getattr__(self, item: str) -> Any:
        """
        Find attribute from model class if this class does not have it.
        """
        if hasattr(self.model, item):
            return getattr(self.model, item)
        else:
            raise AttributeError(f"{self} has no attribute {item}")

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        return self.model._get_torch_save_params()
    def save(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        exclude: Optional[Iterable[str]] = None,
        include: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Save all the attributes of the object and the model parameters in a zip-file.

        :param path: path to the file where the rl agent should be saved
        :param exclude: name of parameters that should be excluded in addition to the default ones
        :param include: name of parameters that might be excluded but should be included anyway
        """

        # add HER parameters to model
        self.model.n_sampled_goal = self.n_sampled_goal
        self.model.goal_selection_strategy = self.goal_selection_strategy
        self.model.online_sampling = self.online_sampling
        self.model.model_class = self.model_class
        self.model.max_episode_length = self.max_episode_length

        self.model.save(path, exclude, include)
    @classmethod
    def load(
        cls,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        env: Optional[GymEnv] = None,
        device: Union[th.device, str] = "auto",
        custom_objects: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> "BaseAlgorithm":
        """
        Load the model from a zip-file

        :param path: path to the file (or a file-like) where to
            load the agent from
        :param env: the new environment to run the loaded model on
            (can be None if you only need prediction from a trained model) has priority over any saved environment
        :param device: Device on which the code should run.
        :param custom_objects: Dictionary of objects to replace
            upon loading. If a variable is present in this dictionary as a
            key, it will not be deserialized and the corresponding item
            will be used instead. Similar to custom_objects in
            ``keras.models.load_model``. Useful when you have an object in
            file that can not be deserialized.
        :param kwargs: extra arguments to change the model when loading
        """
        data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)

        # Remove stored device information and replace with ours
        if "policy_kwargs" in data:
            if "device" in data["policy_kwargs"]:
                del data["policy_kwargs"]["device"]

        if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
            raise ValueError(
                f"The specified policy kwargs do not equal the stored policy kwargs. "
                f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
            )

        # check if observation space and action space are part of the saved parameters
        if "observation_space" not in data or "action_space" not in data:
            raise KeyError("The observation_space and action_space were not given, can't verify new environments")

        # check if given env is valid
        if env is not None:
            # Wrap first if needed
            env = cls._wrap_env(env, data["verbose"])
            # Check if given env is valid
            check_for_correct_spaces(env, data["observation_space"], data["action_space"])
        else:
            # Use stored env, if one exists. If not, continue as is (can be used for predict)
            if "env" in data:
                env = data["env"]

        if "use_sde" in data and data["use_sde"]:
            kwargs["use_sde"] = True

        # Keys that cannot be changed
        for key in {"model_class", "online_sampling", "max_episode_length"}:
            if key in kwargs:
                del kwargs[key]

        # Keys that can be changed
        for key in {"n_sampled_goal", "goal_selection_strategy"}:
            if key in kwargs:
                data[key] = kwargs[key]  # pytype: disable=unsupported-operands
                del kwargs[key]

        # noinspection PyArgumentList
        her_model = cls(
            policy=data["policy_class"],
            env=env,
            model_class=data["model_class"],
            n_sampled_goal=data["n_sampled_goal"],
            goal_selection_strategy=data["goal_selection_strategy"],
            online_sampling=data["online_sampling"],
            max_episode_length=data["max_episode_length"],
            policy_kwargs=data["policy_kwargs"],
            _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
            **kwargs,
        )

        # load parameters
        her_model.model.__dict__.update(data)
        her_model.model.__dict__.update(kwargs)
        her_model._setup_model()

        her_model._total_timesteps = her_model.model._total_timesteps
        her_model.num_timesteps = her_model.model.num_timesteps
        her_model._episode_num = her_model.model._episode_num

        # put state_dicts back in place
        her_model.model.set_parameters(params, exact_match=True, device=device)

        # put other pytorch variables back in place
        if pytorch_variables is not None:
            for name in pytorch_variables:
                recursive_setattr(her_model.model, name, pytorch_variables[name])

        # Sample gSDE exploration matrix, so it uses the right device
        # see issue #44
        if her_model.model.use_sde:
            her_model.model.policy.reset_noise()  # pytype: disable=attribute-error
        return her_model
    def load_replay_buffer(
        self, path: Union[str, pathlib.Path, io.BufferedIOBase], truncate_last_trajectory: bool = True
    ) -> None:
        """
        Load a replay buffer from a pickle file and set environment for replay buffer (only online sampling).

        :param path: Path to the pickled replay buffer.
        :param truncate_last_trajectory: Only for online sampling.
            If set to ``True`` we assume that the last trajectory in the replay buffer was finished.
            If it is set to ``False`` we assume that we continue the same trajectory (same episode).
        """
        self.model.load_replay_buffer(path=path)

        if self.online_sampling:
            # set environment
            self.replay_buffer.set_env(self.env)

            # If we are at the start of an episode, no need to truncate
            current_idx = self.replay_buffer.current_idx

            # truncate interrupted episode
            if truncate_last_trajectory and current_idx > 0:
                warnings.warn(
                    "The last trajectory in the replay buffer will be truncated.\n"
                    "If you are in the same episode as when the replay buffer was saved,\n"
                    "you should use `truncate_last_trajectory=False` to avoid that issue."
                )
                # get current episode and transition index
                pos = self.replay_buffer.pos
                # set episode length for current episode
                self.replay_buffer.episode_lengths[pos] = current_idx
                # set done = True for current episode
                # current_idx was already incremented
                self.replay_buffer.buffer["done"][pos][current_idx - 1] = np.array([True], dtype=np.float32)
                # reset current transition index
                self.replay_buffer.current_idx = 0
                # increment episode counter
                self.replay_buffer.pos = (self.replay_buffer.pos + 1) % self.replay_buffer.max_episode_stored
                # update "full" indicator
                self.replay_buffer.full = self.replay_buffer.full or self.replay_buffer.pos == 0
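

# Minimal usage sketch (not part of the original module), assuming a goal-based environment
# registered under the hypothetical id "MyGoalEnv-v0" whose observations are dicts with
# "observation", "achieved_goal" and "desired_goal" keys, as expected by ``HerReplayBuffer``.
if __name__ == "__main__":
    import gym

    from stable_baselines3 import SAC

    env = gym.make("MyGoalEnv-v0")  # hypothetical id, replace with your own goal env

    model = HER(
        "MlpPolicy",
        env,
        SAC,
        n_sampled_goal=4,
        goal_selection_strategy="future",
        online_sampling=True,
        # only needed if the limit cannot be inferred from the env spec
        max_episode_length=100,
        verbose=1,
    )
    model.learn(total_timesteps=10_000)
    model.save("her_sac_goal_env")

    # Loading re-creates the wrapped off-policy model; pass the env again for further use
    model = HER.load("her_sac_goal_env", env=env)

    # If the replay buffer was saved separately (e.g. with `model.save_replay_buffer(...)`),
    # it could be restored as well (only relevant for online sampling):
    # model.load_replay_buffer("her_replay_buffer.pkl", truncate_last_trajectory=True)

    obs = env.reset()
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)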