"""Policy evaluation utilities (module ``stable_baselines3.common.evaluation``)."""

import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gym
import numpy as np

from stable_baselines3.common import base_class
from stable_baselines3.common.vec_env import VecEnv


def evaluate_policy(
    model: "base_class.BaseAlgorithm",
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
    """
    Runs policy for ``n_eval_episodes`` episodes and returns average reward.
    This is made to work only with one env.

    .. note::
        If environment has not been wrapped with ``Monitor`` wrapper, reward and
        episode lengths are counted as it appears with ``env.step`` calls. If
        the environment contains wrappers that modify rewards or episode lengths
        (e.g. reward scaling, early episode reset), these will affect the evaluation
        results as well. You can avoid this by wrapping environment with ``Monitor``
        wrapper before anything else.

    .. note::
        When the environment is wrapped with ``Monitor``, an episode is only
        recorded once the ``info`` of its last step contains the ``"episode"`` key;
        the ``done`` flag alone is not trusted. The number of inner rollouts may
        therefore exceed ``n_eval_episodes``.

    :param model: The RL agent you want to evaluate.
    :param env: The gym environment. In the case of a ``VecEnv``
        this must contain only one environment.
    :param n_eval_episodes: Number of episodes to evaluate the agent
    :param deterministic: Whether to use deterministic or stochastic actions
    :param render: Whether to render the environment or not
    :param callback: callback function to do additional checks,
        called after each step. Gets locals() and globals() passed as parameters.
        It is called before the step counter ``episode_length`` is incremented.
    :param reward_threshold: Minimum expected reward per episode,
        this will raise an error if the performance is not met
        (the mean reward must be strictly greater than the threshold).
    :param return_episode_rewards: If True, a list of rewards and episode lengths
        per episode will be returned instead of the mean.
    :param warn: If True (default), warns user about lack of a Monitor wrapper in the
        evaluation environment.
    :return: Mean reward per episode, std of reward per episode.
        Returns ([float], [int]) when ``return_episode_rewards`` is True, first
        list containing per-episode rewards and second containing per-episode lengths
        (in number of steps).
    :raises AssertionError: If ``env`` is a ``VecEnv`` holding more than one
        environment, or if ``reward_threshold`` is given and the mean reward
        does not exceed it.
    """
    # Avoid circular import
    from stable_baselines3.common.env_util import is_wrapped
    from stable_baselines3.common.monitor import Monitor

    if isinstance(env, VecEnv):
        assert env.num_envs == 1, "You must pass only one environment when using this function"
        is_monitor_wrapped = env.env_is_wrapped(Monitor)[0]
    else:
        is_monitor_wrapped = is_wrapped(env, Monitor)

    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )

    # NOTE: local variable names below are part of the de-facto callback API,
    # since ``callback`` receives ``locals()``. Do not rename them.
    episode_rewards, episode_lengths = [], []
    not_reseted = True
    while len(episode_rewards) < n_eval_episodes:
        # Number of loops here might differ from true episodes
        # played, if underlying wrappers modify episode lengths.
        # Avoid double reset, as VecEnv are reset automatically.
        if not isinstance(env, VecEnv) or not_reseted:
            obs = env.reset()
            not_reseted = False
        done, state = False, None
        episode_reward = 0.0
        episode_length = 0
        while not done:
            action, state = model.predict(obs, state=state, deterministic=deterministic)
            obs, reward, done, info = env.step(action)
            episode_reward += reward
            if callback is not None:
                callback(locals(), globals())
            episode_length += 1
            if render:
                env.render()

        if is_monitor_wrapped:
            # Do not trust "done" with episode endings.
            # Remove vecenv stacking (if any)
            if isinstance(env, VecEnv):
                info = info[0]
            if "episode" in info:
                # Monitor wrapper includes "episode" key in info if environment
                # has been wrapped with it. Use those rewards instead.
                episode_rewards.append(info["episode"]["r"])
                episode_lengths.append(info["episode"]["l"])
        else:
            if isinstance(env, VecEnv):
                # Remove vecenv stacking from the accumulated reward as well, so
                # that the per-episode list holds plain floats (as annotated)
                # in the same way as for a non-vectorized environment.
                episode_reward = float(np.asarray(episode_reward).ravel()[0])
            episode_rewards.append(episode_reward)
            episode_lengths.append(episode_length)

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    if reward_threshold is not None:
        assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
    if return_episode_rewards:
        return episode_rewards, episode_lengths
    return mean_reward, std_reward