Source code for evorl.envs.wrappers.obs_wrapper

 1import chex
 2import jax
 3import jax.numpy as jnp
 4
 5from evorl.types import Action
 6
 7from ..env import Env, EnvState
 8from ..space import Box, Space
 9from .wrapper import Wrapper
10
11
class ObsFlattenWrapper(Wrapper):
    """Wrapper that flattens multi-dimensional observations into 1D vectors.

    The trailing ``len(env.obs_space.shape)`` axes of every observation
    produced by ``reset``/``step`` are merged into a single axis; any extra
    leading axes in front of them (e.g. batch axes) are left untouched. The
    reported ``obs_space`` is a ``Box`` built from the raveled ``low``/``high``
    bounds of the wrapped environment's observation space.
    """

    def __init__(self, env: Env):
        super().__init__(env)

        # Rank of the wrapped env's observation space. Any axes of `obs`
        # beyond this count (at the front) are preserved when flattening.
        self.obs_ndim = len(env.obs_space.shape)

    def _flatten_obs(self, state: EnvState) -> EnvState:
        """Collapse the observation axes of ``state.obs`` (and ``info.ori_obs``).

        Args:
            state: Environment state returned by the wrapped env.

        Returns:
            A ``replace``-d copy of ``state`` whose ``obs`` has its last
            ``self.obs_ndim`` axes merged into one. If ``state.info`` holds an
            ``"ori_obs"`` entry, it is collapsed starting from the same axis
            and assigned back onto the info object in place.
        """
        # Index of the first observation axis; everything before it is kept.
        first_obs_axis = state.obs.ndim - self.obs_ndim
        flat_state = state.replace(obs=jax.lax.collapse(state.obs, first_obs_axis))

        # NOTE(review): this mutates the info object instead of replacing it;
        # if `replace` is a shallow copy, the info of the state returned by the
        # wrapped env is modified as well — confirm that is intended.
        info = flat_state.info
        if "ori_obs" in info:
            info.ori_obs = jax.lax.collapse(info.ori_obs, first_obs_axis)

        return flat_state

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and return its state with flattened observations."""
        return self._flatten_obs(self.env.reset(key))

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env and return its state with flattened observations."""
        return self._flatten_obs(self.env.step(state, action))

    @property
    def obs_space(self) -> Space:
        """Observation space of the wrapped env with its bounds raveled to 1D."""
        inner_space = self.env.obs_space
        flat_low = jnp.ravel(inner_space.low)
        flat_high = jnp.ravel(inner_space.high)
        return Box(low=flat_low, high=flat_high)