"""Policies: abstract base class and concrete implementations."""
import collections
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from torch import nn
from stable_baselines3.common.distributions import (
BernoulliDistribution,
CategoricalDistribution,
DiagGaussianDistribution,
Distribution,
MultiCategoricalDistribution,
StateDependentNoiseDistribution,
make_proba_distribution,
)
from stable_baselines3.common.preprocessing import get_action_dim, maybe_transpose, preprocess_obs
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
class BaseModel(nn.Module, ABC):
    """
    Common parent of the networks defined in this module (policies and critics).

    A model maps observations to a prediction: an action for a policy,
    a value estimate for a critic. This class only stores the shared
    configuration and offers helpers for feature extraction and
    (de)serialization; ``forward`` must be supplied by subclasses.

    :param observation_space: The observation space of the environment
    :param action_space: The action space of the environment
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        features_extractor: Optional[nn.Module] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        # Environment description
        self.observation_space = observation_space
        self.action_space = action_space

        # Feature extraction: either a ready-made module, or the recipe
        # (class + kwargs) used by ``make_features_extractor``.
        self.features_extractor = features_extractor
        self.features_extractor_class = features_extractor_class
        self.features_extractor_kwargs = {} if features_extractor_kwargs is None else features_extractor_kwargs
        self.normalize_images = normalize_images

        # Optimizer recipe; the optimizer itself is not created here.
        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
        self.optimizer = None  # type: Optional[th.optim.Optimizer]

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Compute the model output; must be implemented by subclasses."""

    def _update_features_extractor(
        self, net_kwargs: Dict[str, Any], features_extractor: Optional[BaseFeaturesExtractor] = None
    ) -> Dict[str, Any]:
        """
        Return a copy of ``net_kwargs`` completed with a features extractor and its output size.

        When ``features_extractor`` is given, that very object is placed in the
        result (i.e. it is shared); otherwise a fresh one is built with
        ``make_features_extractor``. The input dictionary is left untouched.

        :param net_kwargs: the base network keyword arguments, without the ones
            related to features extractor
        :param features_extractor: a features extractor object.
            If None, a new object will be created.
        :return: The updated keyword arguments
        """
        extractor = features_extractor
        if extractor is None:
            # Nothing to share: build a dedicated extractor
            extractor = self.make_features_extractor()

        updated_kwargs = net_kwargs.copy()
        updated_kwargs["features_extractor"] = extractor
        updated_kwargs["features_dim"] = extractor.features_dim
        return updated_kwargs

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Instantiate ``features_extractor_class`` on the observation space with the stored kwargs."""
        extractor_cls = self.features_extractor_class
        return extractor_cls(self.observation_space, **self.features_extractor_kwargs)

    def extract_features(self, obs: th.Tensor) -> th.Tensor:
        """
        Run ``preprocess_obs`` on the observation, then the features extractor.

        :param obs: observation tensor
        :return: output of ``self.features_extractor``
        """
        assert self.features_extractor is not None, "No features extractor was set"
        return self.features_extractor(
            preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
        )

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """
        Collect the constructor arguments stored by ``save`` and replayed by ``load``.

        :return: keyword arguments used to re-create this model
        """
        # Child classes add their own entries, for instance
        # ``squash_output`` or ``features_extractor``.
        return {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "normalize_images": self.normalize_images,
        }

    @property
    def device(self) -> th.device:
        """
        Device of the first parameter of this model.

        Falls back to ``get_device("cpu")`` when the model has no parameters.

        :return: the device
        """
        first_param = next(iter(self.parameters()), None)
        if first_param is None:
            return get_device("cpu")
        return first_param.device

    def save(self, path: str) -> None:
        """
        Write the weights and the constructor parameters to ``path`` with ``th.save``.

        :param path: destination
        """
        payload = {"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}
        th.save(payload, path)

    @classmethod
    def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel":
        """
        Re-create a model from a file written by ``save``.

        :param path: file to read
        :param device: Device on which the policy should be loaded.
        :return: the model, with its weights restored, moved to ``device``
        """
        target_device = get_device(device)
        checkpoint = th.load(path, map_location=target_device)

        # Rebuild the object from the saved constructor arguments ...
        model = cls(**checkpoint["data"])  # pytype: disable=not-instantiable
        # ... then restore the weights
        model.load_state_dict(checkpoint["state_dict"])
        model.to(target_device)
        return model

    def load_from_vector(self, vector: np.ndarray) -> None:
        """
        Overwrite the parameters with the values of a 1D vector.

        :param vector: flat array of parameter values
        """
        flat_params = th.FloatTensor(vector).to(self.device)
        th.nn.utils.vector_to_parameters(flat_params, self.parameters())

    def parameters_to_vector(self) -> np.ndarray:
        """
        Flatten all parameters into a single 1D numpy array.

        :return: the flattened parameters (detached, on CPU)
        """
        flat_params = th.nn.utils.parameters_to_vector(self.parameters())
        return flat_params.detach().cpu().numpy()
class BasePolicy(BaseModel):
    """
    Base class for policies: a ``BaseModel`` that can output actions.

    Constructor arguments are forwarded unchanged to ``BaseModel``, except
    for the extra keyword below.

    :param args: positional arguments passed through to `BaseModel`.
    :param kwargs: keyword arguments passed through to `BaseModel`.
    :param squash_output: For continuous actions, whether the output is squashed
        or not using a ``tanh()`` function.
    """

    def __init__(self, *args, squash_output: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._squash_output = squash_output

    @staticmethod
    def _dummy_schedule(progress_remaining: float) -> float:
        """Constant schedule that ignores its argument and returns 0.0 (useful for pickling policy)."""
        del progress_remaining
        return 0.0

    @property
    def squash_output(self) -> bool:
        """Read-only access to the ``squash_output`` flag given at construction."""
        return self._squash_output

    @staticmethod
    def init_weights(module: nn.Module, gain: float = 1) -> None:
        """
        Orthogonal initialization (used in PPO and A2C).

        Only ``nn.Linear`` and ``nn.Conv2d`` modules are touched: their weight
        is initialized orthogonally with ``gain`` and their bias, if any, is zeroed.

        :param module: module to initialize
        :param gain: gain passed to ``nn.init.orthogonal_``
        """
        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return
        nn.init.orthogonal_(module.weight, gain=gain)
        if module.bias is not None:
            module.bias.data.fill_(0.0)

    @abstractmethod
    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        Abstract: concrete policies must implement it. ``predict`` calls it
        under ``th.no_grad()`` with a batched tensor.

        :param observation: observation tensor
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy
        """

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Get the policy action and state from an observation (and optional state).

        The observation is converted to an array, transposed if needed,
        reshaped into a batch and sent to the policy device before calling
        ``_predict``. For ``Box`` action spaces the result is then either
        un-scaled (squashed output) or clipped to the space bounds.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param mask: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
            (used in recurrent policies); ``state`` is returned as given
        :raises ValueError: if ``state`` is given for a non-vectorized observation
        """
        # TODO (GH/1): add support for RNN policies
        # if state is None:
        #     state = self.initial_state
        # if mask is None:
        #     mask = [False for _ in range(self.n_envs)]
        observation = (
            ObsDictWrapper.convert_dict(observation) if isinstance(observation, dict) else np.array(observation)
        )

        # PyTorch uses the channel-first format, so images may need a transpose
        observation = maybe_transpose(observation, self.observation_space)

        # Must be checked before the reshape below adds the batch dimension
        vectorized_env = is_vectorized_observation(observation, self.observation_space)

        batched_shape = (-1,) + self.observation_space.shape
        obs_tensor = th.as_tensor(observation.reshape(batched_shape)).to(self.device)

        with th.no_grad():
            action_tensor = self._predict(obs_tensor, deterministic=deterministic)
        # Back to numpy
        actions = action_tensor.cpu().numpy()

        if isinstance(self.action_space, gym.spaces.Box):
            if self.squash_output:
                # Squashed output: map back to the action space domain
                actions = self.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)

        if vectorized_env:
            return actions, state

        if state is not None:
            raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
        # Single observation: drop the batch dimension added above
        return actions[0], state

    def scale_action(self, action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [low, high] to [-1, 1]
        (no need for symmetric action space)

        :param action: Action to scale
        :return: Scaled action
        """
        low, high = self.action_space.low, self.action_space.high
        unit_interval = (action - low) / (high - low)
        return 2.0 * unit_interval - 1.0

    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)

        :param scaled_action: Action to un-scale
        :return: Action expressed in the bounds of the action space
        """
        low, high = self.action_space.low, self.action_space.high
        unit_interval = 0.5 * (scaled_action + 1.0)
        return low + (unit_interval * (high - low))
class ActorCriticPolicy(BasePolicy):
    """
    Policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if optimizer_kwargs is None:
            optimizer_kwargs = {}
            # Small values to avoid NaN in Adam optimizer.
            # Only applied to the default kwargs, so that a dictionary
            # supplied by the caller is neither mutated nor overridden.
            if optimizer_class == th.optim.Adam:
                optimizer_kwargs["eps"] = 1e-5

        super(ActorCriticPolicy, self).__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=squash_output,
        )

        # Default network architecture, from stable-baselines
        if net_arch is None:
            if features_extractor_class == FlattenExtractor:
                net_arch = [dict(pi=[64, 64], vf=[64, 64])]
            else:
                net_arch = []

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.ortho_init = ortho_init

        self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
        self.features_dim = self.features_extractor.features_dim

        # Not forwarded to the parent constructor above, hence set here
        self.normalize_images = normalize_images
        self.log_std_init = log_std_init
        dist_kwargs = None
        # Keyword arguments for gSDE distribution
        if use_sde:
            dist_kwargs = {
                "full_std": full_std,
                "squash_output": squash_output,
                "use_expln": use_expln,
                "learn_features": sde_net_arch is not None,
            }

        # Created in ``_build`` when ``sde_net_arch`` is not None
        self.sde_features_extractor = None
        self.sde_net_arch = sde_net_arch
        self.use_sde = use_sde
        self.dist_kwargs = dist_kwargs

        # Action distribution
        self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)

        self._build(lr_schedule)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """
        Extend the parent's constructor parameters with the ones specific to this policy.

        The gSDE related flags are read from ``dist_kwargs``; when gSDE is
        not used (``dist_kwargs`` is None) they are saved as ``None``.

        :return: keyword arguments used to re-create this policy
        """
        data = super()._get_constructor_parameters()

        default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.activation_fn,
                use_sde=self.use_sde,
                log_std_init=self.log_std_init,
                squash_output=default_none_kwargs["squash_output"],
                full_std=default_none_kwargs["full_std"],
                # ``dist_kwargs`` never contains a "sde_net_arch" entry (it only
                # stores "learn_features"), so indexing it raised a KeyError as
                # soon as gSDE was enabled. Use the stored attribute instead.
                sde_net_arch=self.sde_net_arch,
                use_expln=default_none_kwargs["use_expln"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                ortho_init=self.ortho_init,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
            )
        )
        return data

    def reset_noise(self, n_envs: int = 1) -> None:
        """
        Sample new weights for the exploration matrix.

        :param n_envs: passed as ``batch_size`` to ``sample_weights``
        """
        assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
        self.action_dist.sample_weights(self.log_std, batch_size=n_envs)

    def _build_mlp_extractor(self) -> None:
        """
        Create the policy and value networks.
        Part of the layers can be shared.
        """
        # Note: If net_arch is None and some features extractor is used,
        #       net_arch here is an empty list and mlp_extractor does not
        #       really contain any layers (acts like an identity module).
        self.mlp_extractor = MlpExtractor(
            self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device
        )

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the networks and the optimizer.

        :param lr_schedule: Learning rate schedule
            lr_schedule(1) is the initial learning rate
        :raises NotImplementedError: if the action distribution type is not supported
        """
        self._build_mlp_extractor()

        latent_dim_pi = self.mlp_extractor.latent_dim_pi

        # Separate features extractor for gSDE
        if self.sde_net_arch is not None:
            self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
                self.features_dim, self.sde_net_arch, self.activation_fn
            )

        if isinstance(self.action_dist, DiagGaussianDistribution):
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi, log_std_init=self.log_std_init
            )
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            # Without a dedicated gSDE extractor, the policy latent is used as gSDE features
            latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init
            )
        elif isinstance(self.action_dist, CategoricalDistribution):
            self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
        elif isinstance(self.action_dist, BernoulliDistribution):
            self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
        else:
            raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")

        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
        # Init weights: use orthogonal initialization
        # with small initial weight for the output
        if self.ortho_init:
            # TODO: check for features_extractor
            # Values from stable-baselines.
            # features_extractor/mlp values are
            # originally from openai/baselines (default gains/init_scales).
            module_gains = {
                self.features_extractor: np.sqrt(2),
                self.mlp_extractor: np.sqrt(2),
                self.action_net: 0.01,
                self.value_net: 1,
            }
            for module, gain in module_gains.items():
                module.apply(partial(self.init_weights, gain=gain))

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """
        latent_pi, latent_vf, latent_sde = self._get_latent(obs)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        return actions, values, log_prob

    def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Get the latent code (i.e., activations of the last layer of each network)
        for the different networks.

        :param obs: Observation
        :return: Latent codes
            for the actor, the value function and for gSDE function
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)

        # Features for sde: the policy latent, unless a dedicated extractor exists
        latent_sde = latent_pi
        if self.sde_features_extractor is not None:
            latent_sde = self.sde_features_extractor(features)
        return latent_pi, latent_vf, latent_sde

    def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution:
        """
        Retrieve action distribution given the latent codes.

        :param latent_pi: Latent code for the actor
        :param latent_sde: Latent code for the gSDE exploration function
        :return: Action distribution
        :raises ValueError: if the action distribution type is not supported
        """
        mean_actions = self.action_net(latent_pi)

        if isinstance(self.action_dist, DiagGaussianDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std)
        elif isinstance(self.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
        else:
            raise ValueError("Invalid action distribution")

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        :param observation: observation tensor
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy
        """
        latent_pi, _, latent_sde = self._get_latent(observation)
        distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
        return distribution.get_actions(deterministic=deterministic)

    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Evaluate actions according to the current policy,
        given the observations.

        :param obs: observations
        :param actions: actions to evaluate
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        latent_pi, latent_vf, latent_sde = self._get_latent(obs)
        distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        return values, log_prob, distribution.entropy()
class ActorCriticCnnPolicy(ActorCriticPolicy):
    """
    CNN policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    Identical to ``ActorCriticPolicy`` except that the default
    ``features_extractor_class`` is ``NatureCNN``; every argument is
    forwarded unchanged to the parent constructor.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Positional order mirrors the parameter order of ActorCriticPolicy.__init__
        super(ActorCriticCnnPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            sde_net_arch,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )
class ContinuousCritic(BaseModel):
    """
    Critic network(s) for DDPG/SAC/TD3.
    It represents the action-state value function (Q-value function).
    Compared to A2C/PPO critics, this one represents the Q-value
    and takes the continuous action as input. It is concatenated with the state
    and then fed to the network which outputs a single value: Q(s, a).
    For more recent algorithms like SAC/TD3, multiple networks
    are created to give different estimates.

    By default, it creates two critic networks used to reduce overestimation
    thanks to clipped Q-learning (cf TD3 paper).

    :param observation_space: Observation space
    :param action_space: Action space
    :param net_arch: Network architecture
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether the features extractor is shared or not
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: List[int],
        features_extractor: nn.Module,
        features_dim: int,
        activation_fn: Type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
        n_critics: int = 2,
        share_features_extractor: bool = True,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
        )

        self.share_features_extractor = share_features_extractor
        self.n_critics = n_critics
        self.q_networks = []

        # Each Q-network sees the features concatenated with the action
        q_input_dim = features_dim + get_action_dim(self.action_space)
        for critic_idx in range(n_critics):
            layers = create_mlp(q_input_dim, 1, net_arch, activation_fn)
            q_network = nn.Sequential(*layers)
            # ``q_networks`` is a plain list, so register the module
            # explicitly under the name "qf<idx>".
            self.add_module(f"qf{critic_idx}", q_network)
            self.q_networks.append(q_network)

    def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
        """
        Evaluate every Q-network on the (observation, action) pair.

        :param obs: observations
        :param actions: actions, concatenated to the features along dim 1
        :return: one output per Q-network
        """
        # Learn the features extractor using the policy loss only
        # when the features_extractor is shared with the actor
        track_gradients = not self.share_features_extractor
        with th.set_grad_enabled(track_gradients):
            features = self.extract_features(obs)
        q_input = th.cat([features, actions], dim=1)
        return tuple(q_network(q_input) for q_network in self.q_networks)

    def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
        """
        Only predict the Q-value using the first network.
        This allows to reduce computation when all the estimates are not needed
        (e.g. when updating the policy in TD3).

        :param obs: observations
        :param actions: actions, concatenated to the features along dim 1
        :return: output of the first Q-network
        """
        # Features are always extracted without gradient tracking here
        with th.no_grad():
            features = self.extract_features(obs)
        first_q_network = self.q_networks[0]
        q_input = th.cat([features, actions], dim=1)
        return first_q_network(q_input)
def create_sde_features_extractor(
    features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module]
) -> Tuple[nn.Sequential, int]:
    """
    Create the neural network that will be used to extract features
    for the gSDE exploration function.

    :param features_dim: input size of the network
    :param sde_net_arch: sizes of the layers; an empty list means the
        states are used directly as features
    :param activation_fn: activation function, ignored when ``sde_net_arch`` is empty
    :return: the network and the size of its output
    """
    if len(sde_net_arch) > 0:
        sde_activation = activation_fn
        latent_sde_dim = sde_net_arch[-1]
    else:
        # Special case: when using states as features (i.e. sde_net_arch is an empty list)
        # don't use any activation function
        sde_activation = None
        latent_sde_dim = features_dim

    # NOTE(review): an output dim of -1 presumably tells create_mlp not to
    # append an output layer -- confirm against create_mlp.
    layers = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False)
    return nn.Sequential(*layers), latent_sde_dim
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]
def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]:
    """
    Returns the registered policy from the base type and name.
    See `register_policy` for registering policies and explanation.

    :param base_policy_type: the base policy class
    :param name: the policy name
    :return: the policy
    :raises KeyError: if the base policy type has no registered policies,
        or if no policy is registered under ``name`` for that type
    """
    if base_policy_type not in _policy_registry:
        raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")

    registered_policies = _policy_registry[base_policy_type]
    if name not in registered_policies:
        # The two f-strings are concatenated: the trailing space after the
        # comma keeps the sentence readable.
        raise KeyError(
            f"Error: unknown policy type {name}, "
            f"the only registered policy types are: {list(registered_policies.keys())}!"
        )
    return registered_policies[name]
def register_policy(name: str, policy: Type[BasePolicy]) -> None:
    """
    Register a policy, so it can be called using its name.
    e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...).

    The goal here is to standardize policy naming, e.g.
    all algorithms can call upon "MlpPolicy" or "CnnPolicy",
    and they receive respective policies that work for them.
    Consider following:

    OnlinePolicy
    -- OnlineMlpPolicy ("MlpPolicy")
    -- OnlineCnnPolicy ("CnnPolicy")
    OfflinePolicy
    -- OfflineMlpPolicy ("MlpPolicy")
    -- OfflineCnnPolicy ("CnnPolicy")

    Two policies have name "MlpPolicy" and two have "CnnPolicy".
    In `get_policy_from_name`, the parent class (e.g. OnlinePolicy)
    is given and used to select and return the correct policy.

    :param name: the policy name
    :param policy: the policy class
    :raises ValueError: if ``policy`` derives from no direct subclass of
        ``BasePolicy``, or if ``name`` is already taken by a different policy
        for the same base class
    """
    # The registry is keyed by the first direct subclass of BasePolicy
    # that ``policy`` derives from.
    sub_class = next(
        (candidate for candidate in BasePolicy.__subclasses__() if issubclass(policy, candidate)),
        None,
    )
    if sub_class is None:
        raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")

    policies_by_name = _policy_registry.setdefault(sub_class, {})

    # Registering the very same policy again is accepted;
    # replacing an existing entry with another policy is refused.
    if name in policies_by_name and policies_by_name[name] != policy:
        raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.")

    policies_by_name[name] = policy