Source code for evorl.envs.wrappers.wrapper
1import chex
2
3from evorl.types import Action
4
5from ..env import Env, EnvState, Space
6
7
[docs]
class Wrapper(Env):
    """Wraps an environment to allow modular transformations.

    The base implementation is a transparent pass-through: ``reset`` returns
    the wrapped env's state object as-is, and ``step`` hands that same state
    object back to the wrapped env. Attributes not defined on the wrapper are
    looked up on the wrapped env (see ``__getattr__``).
    """

    def __init__(self, env: Env):
        """Initialize the env wrapper.

        Args:
            env: the original env to wrap.
        """
        self.env = env

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and return its state unchanged.

        Args:
            key: PRNG key forwarded to the wrapped env's ``reset``.

        Returns:
            The wrapped env's state, exactly as it returned it.
        """
        return self.env.reset(key)

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env.

        Because ``reset`` (and ``step``) return the wrapped env's state
        unchanged, ``state`` is already the wrapped env's own state and is
        forwarded as-is rather than unpacked.

        Args:
            state: the current state, as produced by this wrapper.
            action: the action to apply.

        Returns:
            The wrapped env's next state.
        """
        return self.env.step(state, action)

    @property
    def obs_space(self) -> Space:
        """Observation space of the wrapped env."""
        return self.env.obs_space

    @property
    def action_space(self) -> Space:
        """Action space of the wrapped env."""
        return self.env.action_space

    @property
    def unwrapped(self) -> Env:
        """The innermost env, reached by following ``env`` through all Wrappers."""
        inner = self.env
        # A single recursive lookup per level. Checking ``hasattr(inner,
        # "unwrapped")`` first would evaluate this recursive property twice
        # per level, i.e. exponentially in the wrapper depth.
        return inner.unwrapped if isinstance(inner, Wrapper) else inner

    def __getattr__(self, name):
        """Forward attributes missing on the wrapper to the wrapped env.

        Only invoked when normal attribute lookup on the wrapper fails.

        Raises:
            AttributeError: for ``env`` and ``__setstate__`` (see below), or if
                the wrapped env lacks ``name``.
        """
        if name in ("env", "__setstate__"):
            # ``env`` missing means ``__init__`` never ran (e.g. the instance
            # was made via ``__new__``); delegating would look up ``self.env``
            # again and recurse forever.
            # ``__setstate__`` is looked up by copy/pickle on such a freshly
            # created instance, so it must not be delegated either.
            raise AttributeError(name)
        return getattr(self.env, name)
44
45
[docs]
def get_wrapper(env: Env, wrapper_cls: type) -> Wrapper | None:
    """Find a layer of a specific wrapper class in an env's wrapper chain.

    The search starts at ``env`` itself and walks inward through successive
    ``.env`` attributes.

    Args:
        env: the (possibly wrapped) env to search, outermost layer first.
        wrapper_cls: the class to look for; instances of subclasses match too.

    Returns:
        The outermost layer that is an instance of ``wrapper_cls``. Returns
        ``None`` if a layer without an ``env`` attribute is reached before
        any match.
    """
    layer = env
    while not isinstance(layer, wrapper_cls):
        if not hasattr(layer, "env"):
            # Bottom of the chain reached without a match.
            return None
        layer = layer.env
    return layer