1import chex
2import jax
3import jax.numpy as jnp
4import jax.tree_util as jtu
5from brax.envs import (
6 Env as BraxEnv,
7 get_environment,
8)
9
10from evorl.types import Action, PyTreeDict
11
12from .env import Env, EnvAdapter, EnvState
13from .space import Box, Space, SpaceContainer
14from .utils import sort_dict
15from .wrappers.training_wrapper import (
16 AutoresetMode,
17 EpisodeWrapper,
18 FastVmapAutoResetWrapper,
19 OneEpisodeWrapper,
20 VmapAutoResetWrapper,
21 VmapEnvPoolAutoResetWrapper,
22 VmapWrapper,
23)
24
25
[docs]
class BraxAdapter(EnvAdapter):
    """Adapter exposing a Brax environment through the ``EnvAdapter`` interface.

    ``reset`` and ``step`` repackage the Brax state into an ``EnvState``; the
    action and observation spaces are derived from the wrapped environment.
    """

    def __init__(self, env: BraxEnv):
        super().__init__(env)

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped Brax env and package the result as an ``EnvState``.

        Args:
            key: PRNG key. It is split and one sub-key is handed to the
                Brax env's ``reset``.

        Returns:
            An ``EnvState`` holding the Brax state, its obs/reward/done, and
            an ``info`` tree that also carries the Brax metrics under
            ``info.metrics``.
        """
        _, reset_key = jax.random.split(key)
        initial = self.env.reset(reset_key)

        # Both info and metrics go through sort_dict before being wrapped.
        info_tree = PyTreeDict(sort_dict(initial.info))
        info_tree.metrics = PyTreeDict(sort_dict(initial.metrics))

        return EnvState(
            env_state=initial,
            obs=initial.obs,
            reward=initial.reward,
            done=initial.done,
            info=info_tree,
        )

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Advance the wrapped Brax env by one step.

        Args:
            state: Current ``EnvState``; its ``env_state`` is passed to Brax.
            action: Action forwarded to the Brax env's ``step``.

        Returns:
            ``state`` with env_state/obs/reward/done replaced by the new Brax
            values and ``info`` (including ``info.metrics``) updated from the
            new Brax info and metrics.
        """
        stepped = self.env.step(state.env_state, action)

        # Update the existing trees in place of rebuilding them, so entries
        # already present in state.info are carried over.
        updated_metrics = state.info.metrics.replace(**stepped.metrics)
        updated_info = state.info.replace(**stepped.info, metrics=updated_metrics)

        new_fields = dict(
            env_state=stepped,
            obs=stepped.obs,
            reward=stepped.reward,
            done=stepped.done,
            info=updated_info,
        )
        return state.replace(**new_fields)

    @property
    def action_space(self) -> Space:
        """``Box`` whose low/high are columns 0 and 1 of the actuator ctrl_range."""
        ctrl_range = jnp.asarray(self.env.sys.actuator.ctrl_range, dtype=jnp.float32)
        lower = ctrl_range[:, 0]
        upper = ctrl_range[:, 1]
        return Box(low=lower, high=upper)

    @property
    def obs_space(self) -> Space:
        """Observation space built from the env's ``observation_size``.

        An int size yields a single ``Box``; any other value is treated as a
        pytree of sizes and yields a ``SpaceContainer`` of ``Box`` leaves.
        Every box is bounded by +/-1e10 in float32.
        """

        def is_shape(node) -> bool:
            # A tuple made only of ints is a shape, i.e. one leaf, and must
            # not be traversed as a container by tree_map.
            return isinstance(node, tuple) and all(isinstance(dim, int) for dim in node)

        def to_box(size) -> Box:
            shape = size if isinstance(size, tuple) else (size,)
            bound = jnp.full(shape, 1e10, dtype=jnp.float32)
            return Box(low=-bound, high=bound)

        size_spec = self.env.observation_size
        if isinstance(size_spec, int):
            return to_box(size_spec)

        return SpaceContainer(
            spaces=jtu.tree_map(to_box, size_spec, is_leaf=is_shape),
        )
88
89
[docs]
def create_brax_env(env_name: str, **kwargs) -> BraxAdapter:
    """Build a Brax environment and wrap it in a ``BraxAdapter``.

    Args:
        env_name: Name passed to Brax's ``get_environment``.
        kwargs: Extra keyword arguments forwarded to ``get_environment``.

    Returns:
        The Brax environment wrapped in a ``BraxAdapter``.
    """
    return BraxAdapter(get_environment(env_name, **kwargs))
104
105
[docs]
def create_wrapped_brax_env(
    env_name: str,
    episode_length: int = 1000,
    parallel: int = 1,
    autoreset_mode: AutoresetMode = AutoresetMode.NORMAL,
    discount: float | None = 1.0,
    record_ori_obs: bool = False,
    **kwargs,
) -> Env:
    """Build a Brax environment wrapped for training.

    The adapter from ``create_brax_env`` is wrapped with an episode wrapper
    and then with the vectorizing wrapper selected by ``autoreset_mode``.

    Args:
        env_name: Environment name.
        episode_length: Max episode length.
        parallel: Number of parallel environments (``num_envs`` of the
            vectorizing wrapper).
        autoreset_mode: Autoreset mode; selects which wrappers are applied.
        discount: Discount factor, passed to ``EpisodeWrapper``. Not used
            when ``autoreset_mode`` is ``AutoresetMode.DISABLED``.
        record_ori_obs: Whether to record the original observation. Forced
            to ``False`` when ``autoreset_mode`` is ``AutoresetMode.ENVPOOL``.
        kwargs: Other arguments passed to Brax via ``create_brax_env``.

    Returns:
        Wrapped Brax env.
    """
    base_env = create_brax_env(env_name, **kwargs)

    # envpool mode will always record last obs, so the flag is forced off there.
    keep_ori_obs = (
        False if autoreset_mode == AutoresetMode.ENVPOOL else record_ori_obs
    )

    if autoreset_mode != AutoresetMode.DISABLED:
        episodic_env = EpisodeWrapper(
            base_env,
            episode_length,
            record_ori_obs=keep_ori_obs,
            discount=discount,
        )

        # Checked in this order; the first matching mode picks the wrapper.
        autoreset_wrappers = (
            (AutoresetMode.FAST, FastVmapAutoResetWrapper),
            (AutoresetMode.NORMAL, VmapAutoResetWrapper),
            (AutoresetMode.ENVPOOL, VmapEnvPoolAutoResetWrapper),
        )
        for mode, wrapper_cls in autoreset_wrappers:
            if autoreset_mode == mode:
                return wrapper_cls(episodic_env, num_envs=parallel)

        # No listed mode matched: the episode-wrapped env is returned without
        # a vectorizing wrapper.
        return episodic_env

    # Autoreset disabled: single-episode wrapper plus a plain vmap wrapper.
    one_episode_env = OneEpisodeWrapper(
        base_env, episode_length, record_ori_obs=keep_ori_obs
    )
    return VmapWrapper(one_episode_env, num_envs=parallel, vmap_step=True)
153 return env