from copy import deepcopy
from typing import Dict, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecTransposeImage(VecEnvWrapper):
    """
    Vectorized-env wrapper that moves the channel axis of image observations
    to the front (HxWxC -> CxHxW), the layout PyTorch convolution layers need.

    Dictionary observation spaces are supported: only the entries detected as
    images are transposed, the others are passed through.

    :param venv: The vectorized environment to wrap.
    :param skip: When ``True`` the wrapper becomes a no-op. The decision to
        apply this wrapper relies on a heuristic that may misfire, so this
        flag allows opting out (see GH issue #671).
    """

    def __init__(self, venv: VecEnv, skip: bool = False):
        source_space = venv.observation_space
        assert is_image_space(source_space) or isinstance(
            source_space, spaces.Dict
        ), "The observation space must be an image or dictionary observation space"

        self.skip = skip
        if skip:
            # Pass-through mode: keep the wrapped env's observation space.
            super().__init__(venv)
            return

        if isinstance(source_space, spaces.Dict):
            self.image_space_keys = []
            transposed_space = deepcopy(source_space)
            for name, subspace in transposed_space.spaces.items():
                if not is_image_space(subspace):
                    continue
                # Remember the key so the matching observations get transposed too
                self.image_space_keys.append(name)
                assert isinstance(subspace, spaces.Box)
                transposed_space.spaces[name] = self.transpose_space(subspace, name)
        else:
            assert isinstance(source_space, spaces.Box)
            transposed_space = self.transpose_space(source_space)  # type: ignore[assignment]

        super().__init__(venv, observation_space=transposed_space)

    @staticmethod
    def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
        """
        Build the channel-first counterpart of a channel-last image space.

        :param observation_space: Image space laid out as (height, width, channels).
        :param key: Key of the space when it belongs to a dictionary space;
            only used in the assertion message.
        :return: A new ``Box`` of shape (channels, height, width), bounds
            0..255 and the same dtype as the input space.
        """
        # Sanity checks: must be an image, and must not already be channel-first
        assert is_image_space(observation_space), "The observation space must be an image"
        already_channel_first = is_image_space_channels_first(observation_space)
        assert not already_channel_first, f"The observation space {key} must follow the channel last convention"

        rows, cols, depth = observation_space.shape
        channel_first_shape = (depth, rows, cols)
        return spaces.Box(
            low=0,
            high=255,
            shape=channel_first_shape,
            dtype=observation_space.dtype,  # type: ignore[arg-type]
        )

    @staticmethod
    def transpose_image(image: np.ndarray) -> np.ndarray:
        """
        Move the channel axis of an image, or of a batch of images, to the front.

        :param image: A 3-dimensional array (single image) or, otherwise, an
            array treated as a 4-dimensional batch of images.
        :return: The array with axes reordered to (2, 0, 1) for a single image
            or (0, 3, 1, 2) for a batch.
        """
        axis_order = (2, 0, 1) if image.ndim == 3 else (0, 3, 1, 2)
        return np.transpose(image, axis_order)

    def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
        """
        Apply the channel re-ordering to observations when the wrapper is active.

        :param observations: An array of images, or a dictionary of arrays.
        :return: The observations unchanged when ``skip`` is set; otherwise the
            transposed array, or a deep copy of the dictionary whose image
            entries have been transposed.
        """
        if self.skip:
            return observations

        if not isinstance(observations, dict):
            return self.transpose_image(observations)

        # Work on a copy so the caller's dictionary is left untouched
        transposed = deepcopy(observations)
        for name in self.image_space_keys:
            transposed[name] = self.transpose_image(transposed[name])
        return transposed

    def step_wait(self) -> VecEnvStepReturn:
        """
        Collect the step results from the wrapped env and transpose the
        observations, including any ``terminal_observation`` stored in the
        info dict of an environment that is done.
        """
        observations, rewards, dones, infos = self.venv.step_wait()

        for env_idx, finished in enumerate(dones):
            if finished and "terminal_observation" in infos[env_idx]:
                env_info = infos[env_idx]
                env_info["terminal_observation"] = self.transpose_observations(env_info["terminal_observation"])

        assert isinstance(observations, (np.ndarray, dict))
        return self.transpose_observations(observations), rewards, dones, infos

    def reset(self) -> Union[np.ndarray, Dict]:
        """
        Reset all environments and return the transposed initial observations.
        """
        initial_obs = self.venv.reset()
        assert isinstance(initial_obs, (np.ndarray, dict))
        return self.transpose_observations(initial_obs)

    def close(self) -> None:
        """
        Close the wrapped vectorized environment.
        """
        self.venv.close()