
Weights & Biases

Weights & Biases provides a callback for experiment tracking that lets you visualize and share results.

The full documentation is available here:

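import gymnasium as gym
import wandb
from wandb.integration.sb3 import WandbCallback

from stable_baselines3 import PPO

config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 25000,
    "env_id": "CartPole-v1",
}
run = wandb.init(
    project="sb3",  # placeholder project name, replace with your own
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    # monitor_gym=True,  # auto-upload the videos of agents playing the game
    # save_code=True,  # optional
)

model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}")
# Train with the W&B callback attached, then close the run
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(verbose=2),
)
run.finish()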

Hugging Face 🤗

The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.

You can see the list of stable-baselines3 saved models here:

Most of them are available via the RL Zoo.

Official pre-trained models are saved in the SB3 organization on the hub:

We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 here.


Installation

pip install huggingface_sb3


If you use the RL Zoo, pushing models to and loading models from the hub is already integrated:

# Download model and save it into the logs/ folder
python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/
# Test the agent
python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v2  -f logs/
# Push model, config and hyperparameters to the hub
python -m rl_zoo3.push_to_hub --algo a2c --env LunarLander-v2 -f logs/ -orga sb3 -m "Initial commit"

Download a model from the Hub

You need to copy the repo-id that contains your saved model. For instance sb3/demo-hf-CartPole-v1:

import gymnasium as gym

from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

# Retrieve the model from the hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository
checkpoint = load_from_hub(
    repo_id="sb3/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
)
model = PPO.load(checkpoint)

# Evaluate the agent and watch it
# render_mode="human" opens a window so that render=True has something to display
eval_env = gym.make("CartPole-v1", render_mode="human")
mean_reward, std_reward = evaluate_policy(
    model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

You need to define two parameters:

  • repo-id: the name of the Hugging Face repo you want to download.

  • filename: the file you want to download.
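Since the repo id follows the `{organization}/{repo_name}` form described above, it can be sanity-checked before any download is attempted. The snippet below is a minimal sketch using only the standard library; `split_repo_id` is a hypothetical helper for illustration and is not part of `huggingface_sb3`.

```python
from typing import Tuple


def split_repo_id(repo_id: str) -> Tuple[str, str]:
    """Split '{organization}/{repo_name}' into its two parts.

    Hypothetical helper for illustration; not part of huggingface_sb3.
    Assumes the two-part form used throughout this page.
    """
    organization, sep, repo_name = repo_id.partition("/")
    # Reject ids with a missing part or more than one slash
    if not sep or not organization or not repo_name or "/" in repo_name:
        raise ValueError(f"expected '<organization>/<repo_name>', got {repo_id!r}")
    return organization, repo_name


organization, repo_name = split_repo_id("sb3/demo-hf-CartPole-v1")
print(organization, repo_name)  # sb3 demo-hf-CartPole-v1
```

A check like this fails fast on a typo in the repo id instead of waiting for the Hub request to fail.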

Upload a model to the Hub

You can easily upload your models using two different functions:

  1. package_to_hub(): save the model, evaluate it, generate a model card and record a replay video of your agent before pushing the complete repo to the Hub.

  2. push_to_hub(): simply push a file to the Hub.

First, you need to be logged in to Hugging Face to upload a model:

  • If you’re using Colab/Jupyter Notebooks:

from huggingface_hub import notebook_login
notebook_login()

  • Otherwise:

huggingface-cli login

Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a new repo, sb3/demo-hf-CartPole-v1.

With package_to_hub()

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from huggingface_sb3 import package_to_hub

# Create the environment
env_id = "CartPole-v1"
env = make_vec_env(env_id, n_envs=1)

# Create the evaluation environment
eval_env = make_vec_env(env_id, n_envs=1)

# Instantiate the agent
model = PPO("MlpPolicy", env, verbose=1)

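# Train the agent
model.learn(total_timesteps=5000)  # small illustrative training budget

# This method saves and evaluates the model, generates a model card and records
# a replay video of your agent before pushing the repo to the hub
package_to_hub(
    model=model,
    model_name="ppo-CartPole-v1",
    model_architecture="PPO",
    env_id=env_id,
    eval_env=eval_env,
    repo_id="sb3/demo-hf-CartPole-v1",
    commit_message="Test commit",
)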

You need to define seven parameters:

  • model: your trained model.

  • model_architecture: name of the architecture of your model (DQN, PPO, A2C, SAC…).

  • env_id: name of the environment.

  • eval_env: environment used to evaluate the agent.

  • repo-id: the name of the Hugging Face repo you want to create or update. It’s <your huggingface username>/<the repo name>.

  • commit-message.

  • model_name: the name under which the trained model file is saved and pushed to the Hub.

With push_to_hub()

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from huggingface_sb3 import push_to_hub

# Create the environment
env_id = "CartPole-v1"
env = make_vec_env(env_id, n_envs=1)

# Instantiate the agent
model = PPO("MlpPolicy", env, verbose=1)

# Train the agent
model.learn(total_timesteps=5000)  # small illustrative training budget

# Save the model
model.save("ppo-CartPole-v1")

# Push this saved model .zip file to the hf repo
# If this repo does not exist, it will be created
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1")
push_to_hub(
    repo_id="sb3/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
    commit_message="Added CartPole-v1 model trained with PPO",
)

You need to define three parameters:

  • repo-id: the name of the Hugging Face repo you want to create or update. It’s <your huggingface username>/<the repo name>.

  • filename: the file you want to push to the Hub.

  • commit-message.
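The filename passed to `push_to_hub()` has to match the file that `model.save()` wrote to disk. The sketch below derives that name from the save name, under the assumption that SB3 appends a `.zip` extension only when the given name has no extension of its own. `saved_zip_name` is a hypothetical helper, not an SB3 or `huggingface_sb3` function.

```python
from pathlib import Path


def saved_zip_name(save_name: str) -> str:
    """Return the file name expected on disk after model.save(save_name).

    Assumption (not confirmed by this page): ".zip" is appended only
    when the given name has no extension.
    """
    path = Path(save_name)
    if path.suffix == "":
        path = path.with_name(path.name + ".zip")
    return str(path)


print(saved_zip_name("ppo-CartPole-v1"))      # ppo-CartPole-v1.zip
print(saved_zip_name("ppo-CartPole-v1.zip"))  # ppo-CartPole-v1.zip
```

Computing the name once and reusing it for both saving and pushing avoids a mismatch between the two calls.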


MLflow

If you want to use MLflow to track your SB3 experiments, you can adapt the following code, which defines a custom logger output:

import sys
from typing import Any, Dict, Tuple, Union

import mlflow
import numpy as np

from stable_baselines3 import SAC
from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger

class MLflowOutputFormat(KVWriter):
    """
    Dumps key/value pairs into MLflow's numeric format.
    """

    def write(
        self,
        key_values: Dict[str, Any],
        key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
        step: int = 0,
    ) -> None:

        for (key, value), (_, excluded) in zip(
            sorted(key_values.items()), sorted(key_excluded.items())
        ):

            # Skip keys that were explicitly excluded from the MLflow output
            if excluded is not None and "mlflow" in excluded:
                continue

            # MLflow metrics are numeric: log scalars, but not strings
            if isinstance(value, np.ScalarType):
                if not isinstance(value, str):
                    mlflow.log_metric(key, value, step)


loggers = Logger(
    folder=None,
    output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()],
)

with mlflow.start_run():
    model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
    # Set custom logger
    model.set_logger(loggers)
    model.learn(total_timesteps=10000, log_interval=1)
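The filtering done inside `write()` can be tried out without MLflow or SB3 installed. The following is a simplified sketch of that logic, not the logger itself: `select_metrics` is a hypothetical name, plain Python `int` and `float` stand in for `np.ScalarType`, and values are looked up by key instead of zipping the two sorted dictionaries.

```python
from typing import Any, Dict, Optional, Tuple, Union

Excluded = Optional[Union[str, Tuple[str, ...]]]


def select_metrics(
    key_values: Dict[str, Any],
    key_excluded: Dict[str, Excluded],
    output_name: str = "mlflow",
) -> Dict[str, float]:
    """Return the entries that an output called `output_name` would log.

    Hypothetical helper: mirrors the two checks in write() above.
    """
    selected: Dict[str, float] = {}
    for key, value in sorted(key_values.items()):
        excluded = key_excluded.get(key)
        # Check 1: the key opted out of this output format
        if excluded is not None and output_name in excluded:
            continue
        # Check 2: only numeric values are kept (simplified to int/float)
        if isinstance(value, (int, float)):
            selected[key] = float(value)
    return selected


key_values = {"train/loss": 0.5, "time/fps": 120, "train/note": "text", "rollout/len": 200.0}
key_excluded = {"train/loss": None, "time/fps": None, "train/note": None, "rollout/len": ("mlflow",)}
print(select_metrics(key_values, key_excluded))
# {'time/fps': 120.0, 'train/loss': 0.5}
```

Here `rollout/len` is dropped because it is excluded from the `mlflow` output, and `train/note` is dropped because it is not numeric.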