Source code for stable_baselines3.common.vec_env.vec_extract_dict_obs
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecExtractDictObs(VecEnvWrapper):
    """
    A vectorized wrapper for extracting dictionary observations.

    The wrapper exposes a single entry of the wrapped environment's
    dictionary observation: its observation space is the sub-space stored
    under ``key``, and the observations returned by ``reset`` and
    ``step_wait`` are the values stored under ``key``.

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """

    def __init__(self, venv: VecEnv, key: str):
        # ``self.key`` is assigned first because it is used to build the
        # argument passed to the parent constructor.
        self.key = key
        super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])

    def reset(self) -> np.ndarray:
        """
        Reset the wrapped environment and extract the selected entry.

        Without this override the observation produced on reset would not
        match the observation space declared in ``__init__``.

        :return: The value stored under ``key`` in the reset observation
        """
        obs = self.venv.reset()
        return obs[self.key]

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the wrapped environment's step and extract the selected entry.

        :return: The value stored under ``key`` in the observation, followed by
            the reward, done flag and infos of the wrapped environment
        """
        obs, reward, done, infos = self.venv.step_wait()
        # An auto-resetting vectorized environment may store the last
        # observation of a finished episode under "terminal_observation".
        # Extract the same entry from it so it is consistent with the
        # observations returned by this wrapper. Only dictionary values are
        # touched, so an already-extracted observation is left as is.
        for info in infos:
            terminal_obs = info.get("terminal_observation")
            if isinstance(terminal_obs, dict):
                info["terminal_observation"] = terminal_obs[self.key]
        return obs[self.key], reward, done, infos