Custom Policy Network
Stable Baselines3 provides policy networks for images (CnnPolicies) and other types of input features (MlpPolicies).
Warning
For A2C and PPO, continuous actions are clipped during training and testing (to avoid out-of-bound errors). SAC, DDPG and TD3 squash the action using a tanh() transformation, which handles bounds more correctly.
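The difference between the two strategies can be seen on a single Box dimension with bounds [-2, 2]. The sketch below is plain Python, not SB3 code: `clip_action` and `squash_action` are hypothetical names, and the linear rescaling `low + 0.5 * (tanh(u) + 1) * (high - low)` is an assumption about how a squashed value is mapped back to the environment bounds.

```python
import math
from typing import List


def clip_action(action: List[float], low: List[float], high: List[float]) -> List[float]:
    """Clip each action component into [low, high]."""
    return [min(max(a, lo), hi) for a, lo, hi in zip(action, low, high)]


def squash_action(raw: List[float], low: List[float], high: List[float]) -> List[float]:
    """Pass each unbounded component through tanh (range -1..1), then rescale linearly to low..high."""
    return [lo + 0.5 * (math.tanh(u) + 1.0) * (hi - lo) for u, lo, hi in zip(raw, low, high)]


low, high = [-2.0], [2.0]
for raw in ([0.5], [5.0], [10.0]):
    # Compare what the environment would receive under each strategy
    print(raw, clip_action(raw, low, high), squash_action(raw, low, high))
```

Under clipping, the raw values 5.0 and 10.0 both become exactly 2.0, so everything beyond the bound collapses onto it; under tanh squashing the mapping stays smooth and monotonic, which is one way to read "handles bounds more correctly".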
SB3 Policy
SB3 networks are separated into two main parts (see figure below):
A features extractor (usually shared between actor and critic when applicable, to save computation) whose role is to extract features (i.e. convert to a feature vector) from high-dimensional observations, for instance, a CNN that extracts features from images. This is the features_extractor_class parameter. You can change the default parameters of that features extractor by passing a features_extractor_kwargs parameter.
A (fully-connected) network that maps the features to actions/value. Its architecture is controlled by the net_arch parameter.
Note
All observations are first pre-processed (e.g. images are normalized, discrete obs are converted to one-hot vectors, …) before being fed to the features extractor.
In the case of vector observations, the features extractor is just a Flatten layer.
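The preprocessing steps named in the note can be illustrated with a dependency-free sketch. This is not SB3's implementation (SB3 operates on batched PyTorch tensors); the helper names are hypothetical, and dividing by 255 is an assumption about what "normalized" means for uint8 images.

```python
from typing import List, Sequence


def normalize_image(pixels: Sequence[int]) -> List[float]:
    """Scale uint8 pixel values (0..255) to floats in [0, 1]."""
    return [p / 255.0 for p in pixels]


def one_hot(obs: int, n: int) -> List[float]:
    """Encode a Discrete(n) observation as a one-hot vector of length n."""
    if not 0 <= obs < n:
        raise ValueError(f"observation {obs} is out of range for Discrete({n})")
    return [1.0 if i == obs else 0.0 for i in range(n)]


def flatten(x: object) -> List[float]:
    """Flatten arbitrarily nested lists/tuples into one feature vector, like a Flatten layer does per sample."""
    if isinstance(x, (list, tuple)):
        out: List[float] = []
        for item in x:
            out.extend(flatten(item))
        return out
    return [float(x)]


print(normalize_image([0, 51, 255]))  # [0.0, 0.2, 1.0]
print(one_hot(2, 4))                  # [0.0, 0.0, 1.0, 0.0]
print(flatten([[1, 2], [3, 4]]))      # [1.0, 2.0, 3.0, 4.0]
```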
SB3 policies are usually composed of several networks (actor/critic networks + target networks when applicable) together with the associated optimizers.
Each of these networks has a features extractor followed by a fully-connected network.
Note
When we refer to “policy” in Stable-Baselines3, this is usually an abuse of language compared to RL terminology. In SB3, “policy” refers to the class that handles all the networks useful for training, so not only the network used to predict actions (the “learned controller”).
Custom Network Architecture
One way of customising the policy network architecture is to pass arguments when creating the model, using the policy_kwargs parameter:
import gym
import torch as th
from stable_baselines3 import PPO
# Custom actor (pi) and value function (vf) networks
# of two layers of size 32 each with Relu activation function
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[dict(pi=[32, 32], vf=[32, 32])])
# Create the agent
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
# Retrieve the environment
env = model.get_env()
# Train the agent
model.learn(total_timesteps=100000)
# Save the agent
model.save("ppo_cartpole")
del model
# the policy_kwargs are automatically loaded
model = PPO.load("ppo_cartpole", env=env)
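To make the net_arch value in the example above concrete, the sketch below splits an on-policy net_arch into shared and separate layers and counts the parameters of each branch. It assumes the list-based convention used by that example: integers before the dict would be layers shared by actor and critic, and a trailing dict gives the separate "pi" and "vf" layers. Both helpers are hypothetical, not SB3 functions. The parameter count treats each branch as plain linear layers (weights plus biases) fed by CartPole's 4-dimensional observation, with 2 outputs for the actor (one per discrete action) and 1 for the value function.

```python
from typing import Dict, List, Tuple, Union

NetArch = List[Union[int, Dict[str, List[int]]]]


def split_net_arch(net_arch: NetArch) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: return (shared, pi, vf) hidden-layer sizes."""
    shared: List[int] = []
    pi: List[int] = []
    vf: List[int] = []
    for i, layer in enumerate(net_arch):
        if isinstance(layer, int):
            shared.append(layer)
        elif isinstance(layer, dict):
            # The dict of non-shared layers is assumed to be the last entry
            if i != len(net_arch) - 1:
                raise ValueError("the dict of separate pi/vf layers must come last")
            pi = list(layer.get("pi", []))
            vf = list(layer.get("vf", []))
        else:
            raise TypeError(f"unexpected net_arch entry: {layer!r}")
    return shared, pi, vf


def count_mlp_params(in_dim: int, hidden: List[int], out_dim: int) -> int:
    """Weights + biases of a stack of fully-connected layers (activations have no parameters)."""
    sizes = [in_dim] + hidden + [out_dim]
    return sum(a * b + b for a, b in zip(sizes[:-1], sizes[1:]))


shared, pi, vf = split_net_arch([dict(pi=[32, 32], vf=[32, 32])])
print(shared, pi, vf)                 # [] [32, 32] [32, 32]
print(count_mlp_params(4, pi, 2))     # 1282
print(count_mlp_params(4, vf, 1))     # 1249
```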
Custom Feature Extractor
If you want to have a custom feature extractor (e.g. a custom CNN when using images), you can define a class that derives from BaseFeaturesExtractor and then pass it to the model when training.
Note
By default, the feature extractor is shared between the actor and the critic to save computation (when applicable). However, this can be changed by defining a custom policy for on-policy algorithms, or by setting share_features_extractor=False in the policy_kwargs for off-policy algorithms (when applicable).
import gym
import torch as th
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomCNN(BaseFeaturesExtractor):
    """
    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units for the last layer.
    """
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super(CustomCNN, self).__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-processing or wrapper
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))
policy_kwargs = dict(
features_extractor_class=CustomCNN,
features_extractor_kwargs=dict(features_dim=128),
)
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
model.learn(1000)
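CustomCNN obtains n_flatten empirically with one forward pass on observation_space.sample(). The same number can be worked out by hand with the usual convolution output-size formula, floor((size + 2 * padding - kernel_size) / stride) + 1 (dilation 1). A small stdlib sketch follows; the input resolutions are assumptions for illustration (84x84 is a common preprocessed Atari frame size, 210x160 RGB the raw Atari frame size), since the actual shape depends on which wrappers are applied.

```python
from typing import List, Tuple


def conv2d_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis of a 2D convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def cnn_flatten_dim(in_channels: int, height: int, width: int,
                    layers: List[Tuple[int, int, int, int]]) -> int:
    """Flattened size after a conv stack; layers are (out_channels, kernel_size, stride, padding)."""
    channels = in_channels
    for out_channels, kernel_size, stride, padding in layers:
        height = conv2d_out(height, kernel_size, stride, padding)
        width = conv2d_out(width, kernel_size, stride, padding)
        channels = out_channels
    return channels * height * width


# The two Conv2d layers of CustomCNN
custom_cnn_layers = [(32, 8, 4, 0), (64, 4, 2, 0)]
print(cnn_flatten_dim(1, 84, 84, custom_cnn_layers))    # 5184  (64 * 9 * 9)
print(cnn_flatten_dim(3, 210, 160, custom_cnn_layers))  # 27648 (64 * 24 * 18)
```

The dummy forward pass used in CustomCNN avoids having to keep this arithmetic in sync by hand whenever the convolutional layers or the input resolution change.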
On-Policy Algorithms
Off-Policy Algorithms
If you need a network architecture that is different for the actor and the critic when using SAC, DDPG or TD3, you can pass a dictionary of the following structure: dict(qf=[<critic network architecture>], pi=[<actor network architecture>]).
For example, if you want a different architecture for the actor (aka pi) and the critic (Q-function aka qf) networks, then you can specify net_arch=dict(qf=[400, 300], pi=[64, 64]).
Otherwise, to have an actor and a critic that share the same network architecture, you only need to specify net_arch=[256, 256] (here, two hidden layers of 256 units each).
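The two accepted forms can be summarized by a small normalization function that turns either one into a pair of (actor layers, critic layers). This is a hypothetical helper written for illustration, not an SB3 function; requiring both "pi" and "qf" keys in the dict form is an assumption of the sketch.

```python
from typing import Dict, List, Tuple, Union


def actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]:
    """Return (actor_layers, critic_layers) for an off-policy net_arch."""
    if isinstance(net_arch, list):
        # Same hidden-layer sizes for both networks (separate copies, no weight sharing)
        return list(net_arch), list(net_arch)
    if isinstance(net_arch, dict):
        missing = {"pi", "qf"} - set(net_arch)
        if missing:
            raise KeyError(f"net_arch dict is missing keys: {sorted(missing)}")
        return list(net_arch["pi"]), list(net_arch["qf"])
    raise TypeError(f"unsupported net_arch type: {type(net_arch).__name__}")


print(actor_critic_arch([256, 256]))                         # ([256, 256], [256, 256])
print(actor_critic_arch(dict(qf=[400, 300], pi=[64, 64])))   # ([64, 64], [400, 300])
```

Even with the list form, the actor and the critic end up as two separate networks with the same layer sizes; as the following note states, they share no layers other than the feature extractor.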
Note
Compared to their on-policy counterparts, no shared layers (other than the feature extractor) between the actor and the critic are allowed (to prevent issues with target networks).
from stable_baselines3 import SAC
# Custom actor architecture with two layers of 64 units each
# Custom critic architecture with two layers of 400 and 300 units
policy_kwargs = dict(net_arch=dict(pi=[64, 64], qf=[400, 300]))
# Create the agent
model = SAC("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs, verbose=1)
model.learn(5000)