# Source code for stable_baselines3.a2c.a2c

from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union

import torch as th
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance

SelfA2C = TypeVar("SelfA2C", bound="A2C")

class A2C(OnPolicyAlgorithm):
    """
    Advantage Actor Critic (A2C)

    NOTE(review): the link targets that originally followed "Paper:", "Code:"
    (code attribution, of which only the name "Stable Baselines" survives) and
    "Introduction to A2C:" are missing from this copy of the file -- restore
    them from the upstream repository.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
        Equivalent to classic advantage when set to 1.
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param rms_prop_eps: RMSProp epsilon. It stabilizes square root computation in denominator
        of RMSProp update
    :param use_rms_prop: Whether to use RMSprop (default) or Adam as optimizer
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
    :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
    :param normalize_advantage: Whether to normalize or not the advantage
        (normalization is skipped when the rollout holds a single sample,
        because the standard deviation is undefined in that case)
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    # Maps the string aliases accepted for ``policy`` to the policy classes.
    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "MlpPolicy": ActorCriticPolicy,
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 7e-4,
        n_steps: int = 5,
        gamma: float = 0.99,
        gae_lambda: float = 1.0,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        rms_prop_eps: float = 1e-5,
        use_rms_prop: bool = True,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        normalize_advantage: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        # The parent is always told not to build the model: the optimizer
        # settings below must be placed in ``policy_kwargs`` first.
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            rollout_buffer_class=rollout_buffer_class,
            rollout_buffer_kwargs=rollout_buffer_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        self.normalize_advantage = normalize_advantage

        # Update optimizer inside the policy if we want to use RMSProp
        # (original implementation) rather than Adam.
        # An explicitly supplied "optimizer_class" always takes precedence.
        if use_rms_prop and "optimizer_class" not in self.policy_kwargs:
            self.policy_kwargs["optimizer_class"] = th.optim.RMSprop
            self.policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0)

        if _init_setup_model:
            self._setup_model()

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer
        (one gradient step over whole data).

        The loss is ``policy_loss + ent_coef * entropy_loss + vf_coef * value_loss``.
        Gradients are clipped to ``max_grad_norm`` before the optimizer step,
        and the update counter plus training statistics are logged afterwards.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)

        # This will only loop once (get all data in one go)
        for rollout_data in self.rollout_buffer.get(batch_size=None):
            actions = rollout_data.actions
            if isinstance(self.action_space, spaces.Discrete):
                # Convert discrete action from float to long
                actions = actions.long().flatten()

            values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
            values = values.flatten()

            # Normalize advantage (not present in the original implementation).
            # Skipped for a single sample: the (Bessel-corrected) standard
            # deviation of one element is NaN, which would propagate through
            # the loss into the policy parameters.
            advantages = rollout_data.advantages
            if self.normalize_advantage and advantages.numel() > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # Policy gradient loss
            policy_loss = -(advantages * log_prob).mean()

            # Value loss using the TD(gae_lambda) target
            value_loss = F.mse_loss(rollout_data.returns, values)

            # Entropy loss favor exploration
            if entropy is None:
                # Approximate entropy when no analytical form
                entropy_loss = -th.mean(-log_prob)
            else:
                entropy_loss = -th.mean(entropy)

            loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

            # Optimization step
            self.policy.optimizer.zero_grad()
            loss.backward()

            # Clip grad norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Computed from the values/returns stored in the buffer during the
        # rollout, not from the values re-evaluated in the loop above.
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        self._n_updates += 1
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/entropy_loss", entropy_loss.item())
        self.logger.record("train/policy_loss", policy_loss.item())
        self.logger.record("train/value_loss", value_loss.item())
        # Only policies exposing a ``log_std`` attribute report a std.
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    def learn(
        self: SelfA2C,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 100,
        tb_log_name: str = "A2C",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfA2C:
        """
        Run the training loop by delegating to the parent ``learn``.

        This override forwards every argument unchanged; it only supplies
        A2C-specific defaults (``log_interval=100``, ``tb_log_name="A2C"``)
        and narrows the annotated return type to the concrete subclass.

        :param total_timesteps: Forwarded to the parent ``learn``
        :param callback: Forwarded to the parent ``learn``
        :param log_interval: Forwarded to the parent ``learn``
        :param tb_log_name: Forwarded to the parent ``learn``
        :param reset_num_timesteps: Forwarded to the parent ``learn``
        :param progress_bar: Forwarded to the parent ``learn``
        :return: Whatever the parent ``learn`` returns (annotated as this instance's type)
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )