Source code for evorl.envs.wrappers.obs_wrapper
1import chex
2import jax
3import jax.numpy as jnp
4
5from evorl.types import Action
6
7from ..env import Env, EnvState
8from ..space import Box, Space
9from .wrapper import Wrapper
10
11
[docs]
class ObsFlattenWrapper(Wrapper):
    """Flatten multi-dimensional observations into a 1D vector.

    Only the trailing axes that make up one observation (as described by the
    wrapped env's ``obs_space``) are merged into a single axis. Any extra
    leading axes on ``state.obs`` (presumably batch/vmap axes; confirm with
    callers) are left untouched. When ``state.info`` carries an ``ori_obs``
    entry, it is flattened with the same number of leading axes preserved.
    """

    def __init__(self, env: Env):
        super().__init__(env)

        # Rank of a single observation, read from the wrapped env's space.
        self.obs_ndim = len(env.obs_space.shape)

    def _flatten_obs(self, state: EnvState) -> EnvState:
        """Return ``state`` with its observation axes collapsed into one.

        Args:
            state: State returned by the wrapped env.

        Returns:
            A state whose ``obs`` (and ``info.ori_obs``, when present) has
            every axis from ``obs.ndim - self.obs_ndim`` onward merged.
        """
        # Count of leading axes that do not belong to a single observation;
        # computed from the un-flattened obs before it is replaced.
        n_lead = state.obs.ndim - self.obs_ndim

        def flatten(arr):
            # With no stop dimension, collapse merges axes n_lead..last.
            return jax.lax.collapse(arr, n_lead)

        state = state.replace(obs=flatten(state.obs))

        # NOTE: ori_obs is written onto the existing info container rather
        # than onto a copy; the replace() call above only swapped ``obs``.
        if "ori_obs" in state.info:
            state.info.ori_obs = flatten(state.info.ori_obs)

        return state

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and flatten the initial observation."""
        return self._flatten_obs(self.env.reset(key))

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env and flatten the resulting observation."""
        return self._flatten_obs(self.env.step(state, action))

    @property
    def obs_space(self) -> Space:
        """Box whose bounds are the wrapped space's bounds raveled to 1D."""
        inner_space = self.env.obs_space
        flat_low = jnp.ravel(inner_space.low)
        flat_high = jnp.ravel(inner_space.high)
        return Box(low=flat_low, high=flat_high)