Source code for evorl.envs.wrappers.wrapper

 1import chex
 2
 3from evorl.types import Action
 4
 5from ..env import Env, EnvState, Space
 6
 7
class Wrapper(Env):
    """Wraps an environment to allow modular transformations.

    The base wrapper forwards ``reset``, ``step``, ``obs_space`` and
    ``action_space`` to the wrapped env unchanged. Any other attribute that
    cannot be found on the wrapper itself is looked up on the wrapped env.
    Subclasses override only the pieces they want to transform.
    """

    def __init__(self, env: Env):
        """Initialize the env wrapper.

        Args:
            env: the original env.
        """
        self.env = env

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and return its state unchanged.

        Args:
            key: PRNG key passed straight to the wrapped env's ``reset``.
        """
        return self.env.reset(key)

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env and return its new state unchanged.

        Args:
            state: the state previously returned by this wrapper's ``reset``
                or ``step``.
            action: action passed straight to the wrapped env's ``step``.
        """
        # ``reset`` returns the wrapped env's state without adding any layer,
        # so ``state`` is exactly the wrapped env's own state and must be
        # forwarded as-is (unwrapping ``state.env_state`` here would hand the
        # inner env the wrong object).
        return self.env.step(state, action)

    @property
    def obs_space(self) -> Space:
        """Observation space of the wrapped env."""
        return self.env.obs_space

    @property
    def action_space(self) -> Space:
        """Action space of the wrapped env."""
        return self.env.action_space

    @property
    def unwrapped(self) -> Env:
        """Return the first env in the wrapper chain that is not a ``Wrapper``."""
        if isinstance(self.env, Wrapper) and hasattr(self.env, "unwrapped"):
            return self.env.unwrapped
        else:
            return self.env

    def __getattr__(self, name):
        """Delegate attribute lookups the wrapper cannot resolve to the wrapped env.

        Raises:
            AttributeError: for ``__setstate__`` and ``env`` (see below), or
                when the wrapped env lacks ``name``.
        """
        # ``__getattr__`` also runs on instances whose ``__init__`` never ran
        # (e.g. while unpickling, before ``__dict__`` is restored). There,
        # ``self.env`` is missing, and evaluating it would re-enter this method
        # with name == "env" forever. Refusing ``__setstate__`` lets pickle fall
        # back to default state restoration; refusing ``env`` turns the
        # endless recursion into a clean AttributeError.
        if name in ("__setstate__", "env"):
            raise AttributeError(name)
        return getattr(self.env, name)
44 45
def get_wrapper(env: Env, wrapper_cls: type) -> Wrapper | None:
    """Return a specific wrapper of an env.

    Walks the chain ``env -> env.env -> env.env.env -> ...`` and returns the
    first object that is an instance of ``wrapper_cls``. Returns ``None`` once
    an object without an ``env`` attribute is reached without a match.
    """
    # Found the requested wrapper at this level.
    if isinstance(env, wrapper_cls):
        return env
    # Reached the bottom of the chain without a match.
    if not hasattr(env, "env"):
        return None
    # Descend one level and keep searching.
    return get_wrapper(env.env, wrapper_cls)