# Source code for stable_baselines3.common.logger

import datetime
import json
import os
import sys
import tempfile
import warnings
from collections import defaultdict
from io import TextIOBase
from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union

import matplotlib.figure
import numpy as np
import pandas
import torch as th

try:
    from torch.utils.tensorboard import SummaryWriter
    from torch.utils.tensorboard.summary import hparams
except ImportError:
    SummaryWriter = None  # type: ignore[misc, assignment]

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None

# Logging levels used by ``Logger.log`` / ``Logger.set_level``; a larger number is more
# severe, and ``DISABLED`` is the highest (``Logger.dump`` writes nothing at that level).
DEBUG, INFO, WARN, ERROR, DISABLED = 10, 20, 30, 40, 50


class Video:
    """
    Container pairing a tensor of video frames with its playback rate.

    :param frames: frames to create the video from
    :param fps: frames per second
    """

    def __init__(self, frames: th.Tensor, fps: float):
        self.frames, self.fps = frames, fps
class Figure:
    """
    Wrapper around a matplotlib figure plus a flag saying whether the writer
    should close the figure once it has been logged.

    :param figure: figure to log
    :param close: if true, close the figure after logging it
    """

    def __init__(self, figure: matplotlib.figure.Figure, close: bool):
        self.figure, self.close = figure, close
class Image:
    """
    Container pairing an image with the layout string describing its axes.

    :param image: image to log
    :param dataformats: Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc.
        More info in add_image method doc at https://pytorch.org/docs/stable/tensorboard.html
        Gym envs normally use 'HWC' (channel last)
    """

    def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str):
        self.image, self.dataformats = image, dataformats
class HParam:
    """
    Hyperparameter data class storing hyperparameters and metrics in dictionaries

    :param hparam_dict: key-value pairs of hyperparameters to log
    :param metric_dict: key-value pairs of metrics to log
        A non-empty metrics dict is required to display hyperparameters in the corresponding Tensorboard section.
    :raises ValueError: if ``metric_dict`` is empty. ``ValueError`` subclasses ``Exception``,
        so existing ``except Exception`` handlers keep working.
    """

    def __init__(self, hparam_dict: Mapping[str, Union[bool, str, float, None]], metric_dict: Mapping[str, float]):
        # Validate before assigning anything so no partially-initialised object is built.
        if not metric_dict:
            raise ValueError("`metric_dict` must not be empty to display hyperparameters to the HPARAMS tensorboard tab.")
        self.hparam_dict = hparam_dict
        self.metric_dict = metric_dict
class FormatUnsupportedError(NotImplementedError):
    """
    Raised when a logged value cannot be rendered by one or more output formats.

    :param unsupported_formats: A sequence of unsupported formats, for instance ``["stdout"]``.
    :param value_description: Description of the value that cannot be logged by this format.
    """

    def __init__(self, unsupported_formats: Sequence[str], value_description: str):
        plural = len(unsupported_formats) > 1
        format_str = (
            f"formats {', '.join(unsupported_formats)} are" if plural else f"format {unsupported_formats[0]} is"
        )
        message = (
            f"The {format_str} not supported for the {value_description} value logged.\n"
            f"You can exclude formats via the `exclude` parameter of the logger's `record` function."
        )
        super().__init__(message)
class KVWriter:
    """
    Key Value writer

    Interface for output formats that consume a dictionary of key/value pairs;
    subclasses override both methods.
    """

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
        """
        Write a dictionary to file

        :param key_values: values to write, indexed by key
        :param key_excluded: per-key tuple of output formats that should skip that key
        :param step: the step these values belong to
        """
        raise NotImplementedError

    def close(self) -> None:
        """
        Close owned resources
        """
        raise NotImplementedError
class SeqWriter:
    """
    sequence writer

    Interface for output formats that accept free-form sequences of strings.
    """

    def write_sequence(self, sequence: List[str]) -> None:
        """
        write_sequence an array to file

        :param sequence: the strings to write
        """
        raise NotImplementedError
class HumanOutputFormat(KVWriter, SeqWriter):
    """A human-readable output format producing ASCII tables of key-value pairs.

    Set attribute ``max_length`` to change the maximum length of keys and values
    to write to output (or specify it when calling ``__init__``).

    :param filename_or_file: the file to write the log to
    :param max_length: the maximum length of keys and values to write to output.
        Outputs longer than this will be truncated. An error will be raised
        if multiple keys are truncated to the same value. The maximum output
        width will be ``2*max_length + 7``. The default of 36 produces output
        no longer than 79 characters wide.
    """

    def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
        self.max_length = max_length
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "w")
            self.own_file = True
        elif isinstance(filename_or_file, TextIOBase) or hasattr(filename_or_file, "write"):
            # Note: in theory `TextIOBase` check should be sufficient,
            # in practice, libraries don't always inherit from it, see GH#1598
            self.file = filename_or_file  # type: ignore[assignment]
            self.own_file = False
        else:
            raise ValueError(f"Expected file or str, got {filename_or_file}")

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
        """
        Render ``key_values`` as an ASCII table and write it to the output.

        Keys of the form ``"tag/name"`` are shown under a ``tag/`` header row,
        with ``name`` indented beneath it.

        :param key_values: values to write, indexed by key
        :param key_excluded: per-key tuple of formats to skip; a key whose tuple
            contains ``"stdout"`` or ``"log"`` is not written
        :param step: unused by this format
        :raises FormatUnsupportedError: for video, figure, image or hparam values
        :raises ValueError: if two keys truncate to the same displayed text
        """
        # Create strings for printing. Entries are keyed by (tag, displayed key) so that
        # identical sub-keys under different tags do not count as collisions.
        key2str: Dict[Tuple[str, str], str] = {}
        for key, value in sorted(key_values.items()):
            # Look exclusions up by key: zipping two independently sorted item lists
            # silently pairs a value with another key's exclusions whenever the two
            # dicts do not hold exactly the same keys.
            excluded = key_excluded.get(key)
            if excluded is not None and ("stdout" in excluded or "log" in excluded):
                continue
            elif isinstance(value, Video):
                raise FormatUnsupportedError(["stdout", "log"], "video")
            elif isinstance(value, Figure):
                raise FormatUnsupportedError(["stdout", "log"], "figure")
            elif isinstance(value, Image):
                raise FormatUnsupportedError(["stdout", "log"], "image")
            elif isinstance(value, HParam):
                raise FormatUnsupportedError(["stdout", "log"], "hparam")
            elif isinstance(value, float):
                # Align left
                value_str = f"{value:<8.3g}"
            else:
                value_str = str(value)

            if key.find("/") > 0:
                # Find tag and add a header row for it; re-assigning an existing dict
                # key keeps its first position, so each header is emitted once.
                tag = key[: key.find("/") + 1]
                key2str[(tag, self._truncate(tag))] = ""
                # Remove tag from key and indent the key
                key = f"{'':3}{key[len(tag) :]}"
            else:
                # Untagged key: do not inherit the previous key's tag, otherwise two
                # untagged keys truncated to the same text could escape the check below.
                tag = ""

            truncated_key = self._truncate(key)
            if (tag, truncated_key) in key2str:
                raise ValueError(
                    f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
                )
            key2str[(tag, truncated_key)] = self._truncate(value_str)

        # Find max widths
        if len(key2str) == 0:
            warnings.warn("Tried to write empty key-value dict")
            return
        key_width = max(len(shown_key) for _, shown_key in key2str)
        val_width = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (key_width + val_width + 7)
        lines = [dashes]
        for (_, key), value in key2str.items():
            key_space = " " * (key_width - len(key))
            val_space = " " * (val_width - len(value))
            lines.append(f"| {key}{key_space} | {value}{val_space} |")
        lines.append(dashes)

        if tqdm is not None and hasattr(self.file, "name") and self.file.name == "<stdout>":
            # Do not mess up with progress bar
            tqdm.write("\n".join(lines) + "\n", file=sys.stdout, end="")
        else:
            self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, string: str) -> str:
        """
        Shorten ``string`` so it fits in ``max_length`` characters, marking the cut with ``...``.

        NOTE(review): assumes ``max_length >= 3``; smaller values yield strings longer than ``max_length``.
        """
        if len(string) > self.max_length:
            string = string[: self.max_length - 3] + "..."
        return string

    def write_sequence(self, sequence: List[str]) -> None:
        """
        Write the strings separated by single spaces, followed by a newline, then flush.

        :param sequence: the strings to write
        """
        self.file.write(" ".join(sequence) + "\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file (only when it was opened by this object)
        """
        if self.own_file:
            self.file.close()
def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]:
    """
    Drop the entries of ``key_values`` that are excluded for ``_format``.

    A key is kept when it has no entry in ``key_excluded``, when its entry is
    ``None``, or when ``_format`` is not part of its entry.

    :param key_values: log dictionary to be filtered
    :param key_excluded: keys to be excluded per format
    :param _format: format for which this filter is run
    :return: a new dict without the excluded keys, in the original insertion order
    """
    kept = {}
    for key, value in key_values.items():
        formats = key_excluded.get(key)
        if formats is not None and _format in formats:
            continue
        kept[key] = value
    return kept
class JSONOutputFormat(KVWriter):
    """
    Log to a file, in the JSON format: each ``write`` call appends one JSON object on its own line.

    :param filename: the file to write the log to
    """

    def __init__(self, filename: str):
        self.file = open(filename, "w")

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
        """
        Append the non-excluded ``key_values`` as a single JSON line.

        :param key_values: values to write, indexed by key
        :param key_excluded: per-key tuple of formats to skip; keys excluded for ``"json"`` are dropped
        :param step: unused by this format
        :raises FormatUnsupportedError: for video, figure, image or hparam values
        """

        def cast_to_json_serializable(value: Any):
            if isinstance(value, Video):
                raise FormatUnsupportedError(["json"], "video")
            if isinstance(value, Figure):
                raise FormatUnsupportedError(["json"], "figure")
            if isinstance(value, Image):
                raise FormatUnsupportedError(["json"], "image")
            if isinstance(value, HParam):
                raise FormatUnsupportedError(["json"], "hparam")
            if hasattr(value, "dtype"):
                # Array-like value (e.g. numpy array/scalar or torch tensor).
                # Count elements instead of checking ``len(value) == 1``: a (1, 3) array
                # has length 1 but ``.item()`` only accepts single-element arrays.
                # ``np.prod(()) == 1``, so 0-d arrays and numpy scalars also land here.
                if int(np.prod(value.shape)) == 1:
                    # single element: serialize as a float
                    return float(value.item())
                # otherwise serialize as a list or nested lists
                return value.tolist()
            return value

        key_values = {
            key: cast_to_json_serializable(value)
            for key, value in filter_excluded_keys(key_values, key_excluded, "json").items()
        }
        self.file.write(json.dumps(key_values) + "\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file
        """
        self.file.close()
class CSVOutputFormat(KVWriter):
    """
    Log to a file, in a CSV format

    When a ``write`` brings keys not seen before, the header is rewritten with the new
    columns appended and every earlier row is padded with empty cells for them.

    :param filename: the file to write the log to
    """

    def __init__(self, filename: str):
        self.file = open(filename, "w+t")
        self.keys: List[str] = []
        self.separator = ","
        self.quotechar = '"'

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
        """
        Append one CSV row holding the non-excluded ``key_values``.

        :param key_values: values to write, indexed by key; missing or ``None`` values give empty cells
        :param key_excluded: per-key tuple of formats to skip; keys excluded for ``"csv"`` are dropped
        :param step: unused by this format
        :raises FormatUnsupportedError: for video, figure, image or hparam values
            (raised before anything is written to the file)
        """
        # Add our current row to the history
        key_values = filter_excluded_keys(key_values, key_excluded, "csv")

        # Validate before touching the file: raising mid-row would leave a partial line
        # (without newline) that the next write would then append to, corrupting the CSV.
        for value in key_values.values():
            if isinstance(value, Video):
                raise FormatUnsupportedError(["csv"], "video")
            elif isinstance(value, Figure):
                raise FormatUnsupportedError(["csv"], "figure")
            elif isinstance(value, Image):
                raise FormatUnsupportedError(["csv"], "image")
            elif isinstance(value, HParam):
                raise FormatUnsupportedError(["csv"], "hparam")

        # Sorted so the column order of newly seen keys does not depend on set
        # iteration order, which for str keys varies between interpreter runs.
        extra_keys = sorted(key_values.keys() - self.keys)
        if extra_keys:
            self.keys.extend(extra_keys)
            # Rewrite the header, then every previous data row padded with one empty
            # cell per new column. The new content is never shorter than the old one.
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            self.file.write(self.separator.join(self.keys))
            self.file.write("\n")
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.separator * len(extra_keys))
                self.file.write("\n")

        for i, key in enumerate(self.keys):
            if i > 0:
                self.file.write(self.separator)
            value = key_values.get(key)
            if isinstance(value, str):
                # escape quotechars by prepending them with another quotechar
                value = value.replace(self.quotechar, self.quotechar + self.quotechar)
                # additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers
                self.file.write(self.quotechar + value + self.quotechar)
            elif value is not None:
                self.file.write(str(value))
        self.file.write("\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file
        """
        self.file.close()
class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.

    :param folder: the folder to write the log to
    """

    def __init__(self, folder: str):
        assert SummaryWriter is not None, "tensorboard is not installed, you can use `pip install tensorboard` to do so"
        self.writer = SummaryWriter(log_dir=folder)
        self._is_closed = False

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
        """
        Send each non-excluded value to the matching ``SummaryWriter`` method, then flush.

        Strings become text, other numpy scalar types become scalars, tensors become
        histograms, and ``Video`` / ``Figure`` / ``Image`` / ``HParam`` wrappers use
        their dedicated writers. Values matching none of these cases are ignored.

        :param key_values: values to write, indexed by key
        :param key_excluded: per-key tuple of formats to skip; keys excluded for ``"tensorboard"`` are skipped
        :param step: global step recorded with every value
        """
        assert not self._is_closed, "The SummaryWriter was closed, please re-create one."

        for key, value in sorted(key_values.items()):
            # Look exclusions up by key: zipping two independently sorted item lists
            # mis-pairs values and exclusions whenever the dicts' key sets differ.
            excluded = key_excluded.get(key)
            if excluded is not None and "tensorboard" in excluded:
                continue

            if isinstance(value, np.ScalarType):
                if isinstance(value, str):
                    # str is considered a np.ScalarType
                    self.writer.add_text(key, value, step)
                else:
                    self.writer.add_scalar(key, value, step)

            if isinstance(value, th.Tensor):
                self.writer.add_histogram(key, value, step)

            if isinstance(value, Video):
                self.writer.add_video(key, value.frames, step, value.fps)

            if isinstance(value, Figure):
                self.writer.add_figure(key, value.figure, step, close=value.close)

            if isinstance(value, Image):
                self.writer.add_image(key, value.image, step, dataformats=value.dataformats)

            if isinstance(value, HParam):
                # we don't use `self.writer.add_hparams` to have control over the log_dir
                experiment, session_start_info, session_end_info = hparams(value.hparam_dict, metric_dict=value.metric_dict)
                self.writer.file_writer.add_summary(experiment)
                self.writer.file_writer.add_summary(session_start_info)
                self.writer.file_writer.add_summary(session_end_info)

        # Flush the output to the file
        self.writer.flush()

    def close(self) -> None:
        """
        closes the file
        """
        if self.writer:
            self.writer.close()
            self._is_closed = True
def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
    """
    Build the writer for one output format.

    :param _format: the requested format to log to ('stdout', 'log', 'json' or 'csv' or 'tensorboard')
    :param log_dir: the logging directory (created if missing, even for an unknown format)
    :param log_suffix: the suffix for the log file
    :return: the logger
    :raises ValueError: if ``_format`` is not one of the names above
    """
    os.makedirs(log_dir, exist_ok=True)
    # Factories are only invoked for the matching name, so no other file is opened.
    builders = (
        ("stdout", lambda: HumanOutputFormat(sys.stdout)),
        ("log", lambda: HumanOutputFormat(os.path.join(log_dir, f"log{log_suffix}.txt"))),
        ("json", lambda: JSONOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.json"))),
        ("csv", lambda: CSVOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.csv"))),
        ("tensorboard", lambda: TensorBoardOutputFormat(log_dir)),
    )
    for name, build in builders:
        if _format == name:
            return build()
    raise ValueError(f"Unknown format specified: {_format}")
# ================================================================
# Backend
# ================================================================
class Logger:
    """
    The logger class.

    Key/value diagnostics are buffered by :meth:`record` / :meth:`record_mean` and
    handed to every ``KVWriter`` in ``output_formats`` by :meth:`dump`. Free-form
    messages passed to :meth:`log` go to every ``SeqWriter`` among them.

    :param folder: the logging location
    :param output_formats: the list of output formats
    """

    def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
        # values recorded since the last dump (``record`` accepts values of any type)
        self.name_to_value: Dict[str, Any] = defaultdict(float)
        # number of samples folded into each running mean by ``record_mean``
        self.name_to_count: Dict[str, int] = defaultdict(int)
        # per-key tuple of output formats that must not receive that key
        self.name_to_excluded: Dict[str, Tuple[str, ...]] = {}
        self.level = INFO
        self.dir = folder
        self.output_formats = output_formats

    @staticmethod
    def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]:
        """
        Helper function to convert str to tuple of str.

        ``None`` maps to ``("",)``, a tuple is returned unchanged and a single
        string is wrapped in a one-element tuple.
        """
        if string_or_tuple is None:
            return ("",)
        if isinstance(string_or_tuple, tuple):
            return string_or_tuple
        return (string_or_tuple,)

    def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
        """
        Log a value of some diagnostic
        Call this once for each diagnostic quantity, each iteration
        If called many times, last value will be used.

        :param key: save to log this key
        :param value: save to log this value
        :param exclude: outputs to be excluded
        """
        self.name_to_value[key] = value
        self.name_to_excluded[key] = self.to_tuple(exclude)

    def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
        """
        The same as record(), but if called many times, values averaged.

        :param key: save to log this key
        :param value: save to log this value; ``None`` is ignored
        :param exclude: outputs to be excluded
        """
        if value is None:
            return
        old_val, count = self.name_to_value[key], self.name_to_count[key]
        # incremental mean: mean_{n+1} = mean_n * n / (n + 1) + value / (n + 1)
        self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1)
        self.name_to_count[key] = count + 1
        self.name_to_excluded[key] = self.to_tuple(exclude)

    def dump(self, step: int = 0) -> None:
        """
        Write all of the diagnostics from the current iteration, then reset the buffers.

        At the ``DISABLED`` level nothing is written, but the buffers are still
        cleared. Otherwise values recorded while disabled would leak into, or be
        averaged into, the first dump after the level is lowered again.

        :param step: the step passed on to every ``KVWriter``
        """
        if self.level != DISABLED:
            for _format in self.output_formats:
                if isinstance(_format, KVWriter):
                    _format.write(self.name_to_value, self.name_to_excluded, step)

        self.name_to_value.clear()
        self.name_to_count.clear()
        self.name_to_excluded.clear()

    def log(self, *args, level: int = INFO) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).

        Nothing is written when the logger's own level is higher than ``level``.

        :param args: log the arguments
        :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
        """
        if self.level <= level:
            self._do_log(args)

    def debug(self, *args) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).
        Using the DEBUG level.

        :param args: log the arguments
        """
        self.log(*args, level=DEBUG)

    def info(self, *args) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).
        Using the INFO level.

        :param args: log the arguments
        """
        self.log(*args, level=INFO)

    def warn(self, *args) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).
        Using the WARN level.

        :param args: log the arguments
        """
        self.log(*args, level=WARN)

    def error(self, *args) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).
        Using the ERROR level.

        :param args: log the arguments
        """
        self.log(*args, level=ERROR)

    # Configuration
    # ----------------------------------------
    def set_level(self, level: int) -> None:
        """
        Set logging threshold on current logger.

        :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
        """
        self.level = level

    def get_dir(self) -> Optional[str]:
        """
        Get directory that log files are being written to.
        will be None if there is no output directory (i.e., if you didn't call start)

        :return: the logging directory
        """
        return self.dir

    def close(self) -> None:
        """
        closes the file
        """
        for _format in self.output_formats:
            _format.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args: Tuple[Any, ...]) -> None:
        """
        log to the requested format outputs

        :param args: the arguments to log (each converted with ``str``)
        """
        for _format in self.output_formats:
            if isinstance(_format, SeqWriter):
                _format.write_sequence(list(map(str, args)))
def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:
    """
    Configure the current logger.

    :param folder: the save location
        (if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time])
    :param format_strings: the output logging format
        (if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
    :return: The logger object.
    """
    # Resolve the output folder: explicit argument, then environment, then a fresh temp dir.
    if folder is None:
        folder = os.getenv("SB3_LOGDIR")
    if folder is None:
        tmp_root = tempfile.gettempdir()
        folder = os.path.join(tmp_root, datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f"))
    assert isinstance(folder, str)
    os.makedirs(folder, exist_ok=True)

    log_suffix = ""
    if format_strings is None:
        format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",")
    # Drop empty entries, e.g. from an empty environment variable or trailing commas.
    format_strings = [name for name in format_strings if name]

    writers = [make_output_format(name, folder, log_suffix) for name in format_strings]
    logger = Logger(folder=folder, output_formats=writers)
    # Only print when some files will be saved
    if format_strings and format_strings != ["stdout"]:
        logger.log(f"Logging to {folder}")
    return logger
# ================================================================
# Readers
# ================================================================
def read_json(filename: str) -> pandas.DataFrame:
    """
    Load a JSON-lines log (one JSON object per line) into a DataFrame.

    :param filename: the file path to read
    :return: a DataFrame with one row per line of the file
    """
    with open(filename) as json_file:
        records = [json.loads(row) for row in json_file]
    return pandas.DataFrame(records)
def read_csv(filename: str) -> pandas.DataFrame:
    """
    Load a CSV log into a DataFrame, treating ``#`` as the start of a comment.

    :param filename: the file path to read
    :return: the data in the csv, with a default integer index
    """
    return pandas.read_csv(filename, comment="#", index_col=None)