Source code for stable_baselines3.common.envs.multi_input_envs

from typing import Dict, Union

import gym
import numpy as np
from gym import spaces

from stable_baselines3.common.type_aliases import GymStepReturn


class SimpleMultiObsEnv(gym.Env):
    """
    Base class for GridWorld-based MultiObs environments: a 4x4 grid world by default.

    .. code-block:: text

         ____________
        | 0  1  2  3 |
        | 4|¯5¯¯6¯| 7 |
        | 8|_9_10_|11 |
        |12 13 14 15 |
         ¯¯¯¯¯¯¯¯¯¯¯¯

    start is 0
    states 5, 6, 9, and 10 are blocked
    goal is 15
    actions are = [left, down, right, up]

    Simple grid env whose state is encoded with both a vector and an image observation:
    each column is represented by a random vector and each row is
    represented by a random image, both sampled once at creation time.

    :param num_col: Number of columns in the grid
    :param num_row: Number of rows in the grid
    :param random_start: If true, agent starts in random position
    :param discrete_actions: If true, the action space is Discrete(4), else a Box of 4 values in [0, 1]
    :param channel_last: If true, the image will be channel last, else it will be channel first
    """

    def __init__(
        self,
        num_col: int = 4,
        num_row: int = 4,
        random_start: bool = True,
        discrete_actions: bool = True,
        channel_last: bool = True,
    ):
        super().__init__()

        self.vector_size = 5
        if channel_last:
            self.img_size = [64, 64, 1]
        else:
            self.img_size = [1, 64, 64]

        self.random_start = random_start
        self.discrete_actions = discrete_actions
        if discrete_actions:
            self.action_space = spaces.Discrete(4)
        else:
            self.action_space = spaces.Box(0, 1, (4,))

        self.observation_space = spaces.Dict(
            spaces={
                "vec": spaces.Box(0, 1, (self.vector_size,), dtype=np.float64),
                "img": spaces.Box(0, 255, self.img_size, dtype=np.uint8),
            }
        )
        self.count = 0
        # Timeout
        self.max_count = 100
        self.log = ""
        self.state = 0
        self.action2str = ["left", "down", "right", "up"]
        self.init_possible_transitions()

        self.num_col = num_col
        self.state_mapping = []
        self.init_state_mapping(num_col, num_row)

        self.max_state = len(self.state_mapping) - 1
    def init_state_mapping(self, num_col: int, num_row: int) -> None:
        """
        Initializes the state_mapping array which holds the observation values for each state.

        :param num_col: Number of columns.
        :param num_row: Number of rows.
        """
        # Each column is represented by a random vector
        col_vecs = np.random.random((num_col, self.vector_size))
        # Each row is represented by a random image
        row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.uint8)

        for i in range(num_col):
            for j in range(num_row):
                self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)})
    def get_state_mapping(self) -> Dict[str, np.ndarray]:
        """
        Uses the state to get the observation mapping.

        :return: observation dict {'vec': ..., 'img': ...}
        """
        return self.state_mapping[self.state]
    def init_possible_transitions(self) -> None:
        """
        Initializes the transitions of the environment.

        The environment exploits the cardinal directions of the grid by noting that
        they correspond to simple addition and subtraction of the cell id within the grid:

        - up => moving up a row => subtract the number of columns (num_col)
        - down => moving down a row => add the number of columns (num_col)
        - left => moving left by one => subtract 1
        - right => moving right by one => add 1

        Thus one only needs to specify in which states each action is possible
        in order to define the transitions of the environment.
        """
        self.left_possible = [1, 2, 3, 13, 14, 15]
        self.down_possible = [0, 4, 8, 3, 7, 11]
        self.right_possible = [0, 1, 2, 12, 13, 14]
        self.up_possible = [4, 8, 12, 7, 11, 15]
    def step(self, action: Union[float, np.ndarray]) -> GymStepReturn:
        """
        Run one timestep of the environment's dynamics. When end of
        episode is reached, you are responsible for calling `reset()`
        to reset this environment's state.
        Accepts an action and returns a tuple (observation, reward, done, info).

        :param action: the action of the agent
        :return: tuple (observation, reward, done, info).
        """
        if not self.discrete_actions:
            action = np.argmax(action)
        else:
            action = int(action)

        self.count += 1

        prev_state = self.state

        reward = -0.1
        # define state transition
        if self.state in self.left_possible and action == 0:  # left
            self.state -= 1
        elif self.state in self.down_possible and action == 1:  # down
            self.state += self.num_col
        elif self.state in self.right_possible and action == 2:  # right
            self.state += 1
        elif self.state in self.up_possible and action == 3:  # up
            self.state -= self.num_col

        got_to_end = self.state == self.max_state
        reward = 1 if got_to_end else reward
        done = self.count > self.max_count or got_to_end

        self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}"

        return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end}
    def render(self, mode: str = "human") -> None:
        """
        Prints the log of the environment.

        :param mode: render mode (unused, the log is always printed)
        """
        print(self.log)
    def reset(self) -> Dict[str, np.ndarray]:
        """
        Resets the environment state and step count and returns the reset observation.

        :return: observation dict {'vec': ..., 'img': ...}
        """
        self.count = 0
        if not self.random_start:
            self.state = 0
        else:
            self.state = np.random.randint(0, self.max_state)
        return self.state_mapping[self.state]
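

# Usage sketch (illustrative only): a minimal random-action rollout, assuming the
# default constructor arguments above. `check_env` from
# stable_baselines3.common.env_checker is optional and simply validates that the
# Dict observation space and the action space are well-formed.
if __name__ == "__main__":
    from stable_baselines3.common.env_checker import check_env

    env = SimpleMultiObsEnv(random_start=False)
    check_env(env, warn=True)

    obs = env.reset()
    done = False
    while not done:
        # Actions: 0=left, 1=down, 2=right, 3=up
        obs, reward, done, info = env.step(env.action_space.sample())
        env.render()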