Source code for stable_baselines3.dqn.dqn

import warnings
from typing import Any, ClassVar, Optional, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork

SelfDQN = TypeVar("SelfDQN", bound="DQN")


[docs] class DQN(OffPolicyAlgorithm): """ Deep Q-Network (DQN) Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236 Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str) :param learning_rate: The learning rate, it can be a function of the current progress remaining (from 1 to 0) :param buffer_size: size of the replay buffer :param learning_starts: how many steps of the model to collect transitions for before learning starts :param batch_size: Minibatch size for each gradient update :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update :param gamma: the discount factor :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit like ``(5, "step")`` or ``(2, "episode")``. :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). If ``None``, it will be automatically selected. :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 :param target_update_interval: update the target network every ``target_update_interval`` environment steps. :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced :param exploration_initial_eps: initial value of random action probability :param exploration_final_eps: final value of random action probability :param max_grad_norm: The maximum value for the gradient clipping :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`dqn_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. :param _init_setup_model: Whether or not to build the network at the creation of the instance """ policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { "MlpPolicy": MlpPolicy, "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } # Linear schedule will be defined in `_setup_model()` exploration_schedule: Schedule q_net: QNetwork q_net_target: QNetwork policy: DQNPolicy def __init__( self, policy: Union[str, type[DQNPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-4, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 32, tau: float = 1.0, gamma: float = 0.99, train_freq: Union[int, tuple[int, str]] = 4, gradient_steps: int = 1, replay_buffer_class: Optional[type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[dict[str, Any]] = None, optimize_memory_usage: bool = False, target_update_interval: int = 10000, exploration_fraction: float = 0.1, exploration_initial_eps: float = 1.0, exploration_final_eps: float = 0.05, max_grad_norm: float = 10, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[th.device, str] = "auto", _init_setup_model: bool = True, ) -> None: super().__init__( policy, env, learning_rate, buffer_size, learning_starts, batch_size, tau, gamma, train_freq, gradient_steps, action_noise=None, # No action noise replay_buffer_class=replay_buffer_class, replay_buffer_kwargs=replay_buffer_kwargs, policy_kwargs=policy_kwargs, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, verbose=verbose, device=device, seed=seed, sde_support=False, optimize_memory_usage=optimize_memory_usage, supported_action_spaces=(spaces.Discrete,), support_multi_env=True, ) self.exploration_initial_eps = exploration_initial_eps self.exploration_final_eps = exploration_final_eps self.exploration_fraction = exploration_fraction self.target_update_interval = target_update_interval # For updating the target network with multiple envs: self._n_calls = 0 self.max_grad_norm = max_grad_norm # "epsilon" for the epsilon-greedy exploration self.exploration_rate = 0.0 if _init_setup_model: self._setup_model() def _setup_model(self) -> None: super()._setup_model() self._create_aliases() # Copy running stats, see GH issue #996 self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"]) self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"]) self.exploration_schedule = get_linear_fn( self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction, ) if self.n_envs > 1: if self.n_envs > self.target_update_interval: warnings.warn( "The number of environments used is greater than the target network " f"update interval ({self.n_envs} > {self.target_update_interval}), " "therefore the target network will be updated after each call to env.step() " f"which corresponds to {self.n_envs} steps." ) def _create_aliases(self) -> None: self.q_net = self.policy.q_net self.q_net_target = self.policy.q_net_target def _on_step(self) -> None: """ Update the exploration rate and target network if needed. This method is called in ``collect_rollouts()`` after each step in the environment. """ self._n_calls += 1 # Account for multiple environments # each call to step() corresponds to n_envs transitions if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0: polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau) # Copy running stats, see GH issue #996 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) self.exploration_rate = self.exploration_schedule(self._current_progress_remaining) self.logger.record("rollout/exploration_rate", self.exploration_rate)
[docs] def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Switch to train mode (this affects batch norm / dropout) self.policy.set_training_mode(True) # Update learning rate according to schedule self._update_learning_rate(self.policy.optimizer) losses = [] for _ in range(gradient_steps): # Sample replay buffer replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] with th.no_grad(): # Compute the next Q-values using the target network next_q_values = self.q_net_target(replay_data.next_observations) # Follow greedy policy: use the one with the highest value next_q_values, _ = next_q_values.max(dim=1) # Avoid potential broadcast issue next_q_values = next_q_values.reshape(-1, 1) # 1-step TD target target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values # Get current Q-values estimates current_q_values = self.q_net(replay_data.observations) # Retrieve the q-values for the actions from the replay buffer current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long()) # Compute Huber loss (less sensitive to outliers) loss = F.smooth_l1_loss(current_q_values, target_q_values) losses.append(loss.item()) # Optimize the policy self.policy.optimizer.zero_grad() loss.backward() # Clip gradient norm th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() # Increase update counter self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/loss", np.mean(losses))
[docs] def predict( self, observation: Union[np.ndarray, dict[str, np.ndarray]], state: Optional[tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]: """ Overrides the base_class predict function to include epsilon-greedy exploration. :param observation: the input observation :param state: The last states (can be None, used in recurrent policies) :param episode_start: The last masks (can be None, used in recurrent policies) :param deterministic: Whether or not to return deterministic actions. :return: the model's action and the next state (used in recurrent policies) """ if not deterministic and np.random.rand() < self.exploration_rate: if self.policy.is_vectorized_observation(observation): if isinstance(observation, dict): n_batch = observation[next(iter(observation.keys()))].shape[0] else: n_batch = observation.shape[0] action = np.array([self.action_space.sample() for _ in range(n_batch)]) else: action = np.array(self.action_space.sample()) else: action, state = self.policy.predict(observation, state, episode_start, deterministic) return action, state
[docs] def learn( self: SelfDQN, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "DQN", reset_num_timesteps: bool = True, progress_bar: bool = False, ) -> SelfDQN: return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps, progress_bar=progress_bar, )
def _excluded_save_params(self) -> list[str]: return [*super()._excluded_save_params(), "q_net", "q_net_target"] def _get_torch_save_params(self) -> tuple[list[str], list[str]]: state_dicts = ["policy", "policy.optimizer"] return state_dicts, []