1from enum import Enum
2
3import chex
4import jax
5import jax.numpy as jnp
6import jax.tree_util as jtu
7from evorl.utils.jax_utils import rng_split
8
9from ..env import Env, EnvState
10from .wrapper import Wrapper
11
12
[docs]
class EpisodeWrapper(Wrapper):
    """Maintains episode step count and sets done at episode end.

    This is the same as brax's EpisodeWrapper, and adds some new fields to
    ``state.info``:

    - steps: the current step count of the episode.
    - truncation: whether the episode is truncated by ``episode_length``.
    - termination: whether the episode is terminated by the wrapped env.
    - ori_obs: the next observation without autoreset
      (only recorded when ``record_ori_obs=True``).
    - episode_return: the current sum of discounted rewards of the episode
      (only recorded when ``discount`` is not None).
    """

    def __init__(
        self,
        env: Env,
        episode_length: int,
        record_ori_obs: bool = True,
        discount: float | None = None,
    ):
        """Initializes the env wrapper.

        Args:
            env: the wrapped env; it should be a single un-vectorized environment.
            episode_length: the maximum length of each episode before truncation.
            record_ori_obs: whether to record the real next observation of each
                step in ``info.ori_obs``.
            discount: the discount factor for computing ``info.episode_return``.
                Default is None, which means the episode return is not recorded.
        """
        super().__init__(env)
        self.episode_length = episode_length
        self.record_ori_obs = record_ori_obs
        self.record_episode_return = discount is not None
        self.discount = discount

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and initialize the episode bookkeeping fields.

        The fields are created with the same shape as ``state.done`` (and, for
        termination/truncation, the same dtype) so they match what ``step``
        writes later and can be carried through ``lax.scan`` unchanged.
        """
        state = self.env.reset(key)
        done = state.done

        info = state.info.replace(
            steps=jnp.zeros_like(done, dtype=jnp.int32),
            # step() stores `state.done` here, so mirror its dtype/shape.
            termination=jnp.zeros_like(done),
            truncation=jnp.zeros_like(done),
        )

        if self.record_ori_obs:
            info.ori_obs = jtu.tree_map(jnp.zeros_like, state.obs)
        if self.record_episode_return:
            info.episode_return = jnp.zeros(jnp.shape(done))

        return state.replace(info=info)

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        return self._step(state, action)

    def _step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step the wrapped env and update the episode bookkeeping fields.

        The step counter (and episode return) are reset here, on the first
        step after the previous episode ended, not by an autoreset wrapper.
        """
        prev_done = state.done
        # reset steps when prev episode is done (truncation or termination)
        steps = state.info.steps * (1 - prev_done).astype(jnp.int32)

        if self.record_episode_return:
            # reset the episode_return when the prev episode is done
            episode_return = state.info.episode_return * (1 - prev_done)

        # ============== pre update ==============
        state = self.env.step(state, action)

        # ============== post update ==============
        steps = steps + 1
        time_limit_reached = steps >= self.episode_length

        done = jnp.where(
            time_limit_reached, jnp.ones_like(state.done), state.done
        )

        info = state.info.replace(
            steps=steps,
            termination=state.done,
            # Note: when termination and truncation both happen at the last
            # step, we count it as termination and set truncation=0.
            truncation=jnp.where(
                time_limit_reached, 1 - state.done, jnp.zeros_like(state.done)
            ),
        )

        if self.record_ori_obs:
            # the real next_obs at the end of episodes, where
            # state.obs could be changed to the next episode's initial state
            # by VmapAutoResetWrapper
            info.ori_obs = state.obs  # i.e. obs at t+1

        if self.record_episode_return:
            if self.discount == 1.0:  # a shortcut for discount=1.0
                episode_return += state.reward
            else:
                # `steps` is already incremented, so the t-th reward (1-based)
                # is weighted by discount**(t-1).
                episode_return += jnp.power(self.discount, steps - 1) * state.reward
            info.episode_return = episode_return

        return state.replace(done=done, info=info)
110
111
[docs]
class OneEpisodeWrapper(EpisodeWrapper):
    """EpisodeWrapper that freezes the state once the episode is done.

    Calling ``step()`` after ``state.done`` is set still evaluates the
    underlying step. Its result is then discarded, and the previous state is
    returned leaf by leaf.
    """

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        already_done = state.done
        candidate = self._step(state, action)

        def _freeze(prev_leaf, next_leaf):
            # Keep the old leaf for a finished episode, otherwise take the new one.
            return jnp.where(already_done, prev_leaf, next_leaf)

        # A leaf-wise jnp.where is used instead of lax.cond: under jax.vmap a
        # batched lax.cond is turned into a select that evaluates both
        # branches, which breaks the warp backend's custom_vmap for mjx.step.
        return jtu.tree_map(_freeze, state, candidate)
129
130
[docs]
class VmapWrapper(Wrapper):
    """Run ``num_envs`` copies of a single env as one batched env (no autoreset)."""

    def __init__(self, env: Env, num_envs: int = 1, vmap_step: bool = False):
        """Initialize the env wrapper.

        Args:
            env: the original (un-vectorized) env
            num_envs: number of envs to vectorize
            vmap_step: if True, batch ``step`` with ``jax.vmap``; otherwise
                step the envs one after another with ``jax.lax.map``.
        """
        super().__init__(env)
        self.num_envs = num_envs
        self.vmap_step = vmap_step

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset all envs.

        Args:
            key: either one key of shape [2], which is split into
                ``num_envs`` keys, or a batch of keys of shape [num_envs, 2].
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
            env_keys = key
        else:
            env_keys = jax.random.split(key, self.num_envs)

        return jax.vmap(self.env.reset)(env_keys)

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        if not self.vmap_step:
            # Sequential map over the batch axis.
            return jax.lax.map(lambda args: self.env.step(*args), (state, action))
        return jax.vmap(self.env.step)(state, action)
168
169
[docs]
class VmapAutoResetWrapper(Wrapper):
    """Vectorize env with ``jax.vmap`` and autoreset finished sub-envs.

    After each batched step, every sub-env whose ``done`` is set gets its
    ``env_state`` and ``obs`` replaced by a fresh ``env.reset()``. The reset
    uses a per-env key stored in ``state._internal.reset_key``. The other
    fields of the stepped state (e.g. reward, done, info) are kept.
    """

    def __init__(self, env: Env, num_envs: int = 1):
        """Initialize the env wrapper.

        Args:
            env: the original (un-vectorized) env
            num_envs: number of parallel envs.
        """
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset all envs and seed the per-env autoreset keys.

        Args:
            key: either one key of shape [2], which is split into
                ``num_envs`` keys, or a batch of keys of shape [num_envs, 2].
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        autoreset_keys, env_keys = rng_split(key)
        batched_state = jax.vmap(self.env.reset)(env_keys)
        batched_state._internal.reset_key = autoreset_keys  # consumed by _auto_reset

        return batched_state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        stepped = jax.vmap(self.env.step)(state, action)
        # lax.map (instead of vmap) keeps lax.cond a real branch rather than a
        # lax.select, so reset only runs for the sub-envs that are done.
        return jax.lax.map(self._maybe_reset, stepped)

    def _auto_reset(self, state: EnvState) -> EnvState:
        """Replace one sub-env's env_state/obs with a freshly reset one.

        The stored reset key is split so that every reset draws a new key.
        """
        next_reset_key, this_reset_key = jax.random.split(state._internal.reset_key)
        fresh = self.env.reset(this_reset_key)

        return state.replace(
            env_state=fresh.env_state,
            obs=fresh.obs,
            _internal=state._internal.replace(reset_key=next_reset_key),
        )

    def _maybe_reset(self, state: EnvState) -> EnvState:
        """Autoreset one sub-env if its episode is done, else keep it."""
        keep = lambda s: s.replace()  # noqa: E731
        return jax.lax.cond(state.done, self._auto_reset, keep, state)
238
239
[docs]
class FastVmapAutoResetWrapper(Wrapper):
    """Brax-style AutoReset: no randomness in reset.

    ``reset()`` caches the initial ``env_state`` and ``obs`` of every sub-env
    in ``state._internal.first_env_state`` / ``state._internal.first_obs``.
    After each step, sub-envs whose ``done`` is set are restored from that
    cache instead of calling ``env.reset()`` again. When episodes are short or
    ``env.reset()`` is expensive, this wrapper is more efficient than
    ``VmapAutoResetWrapper``.
    """

    def __init__(self, env: Env, num_envs: int = 1):
        """Initialize the env wrapper.

        Args:
            env: the original (un-vectorized) env
            num_envs: number of parallel envs.
        """
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset all envs and cache their initial env_state/obs.

        Args:
            key: either one key of shape [2], which is split into
                ``num_envs`` keys, or a batch of keys of shape [num_envs, 2].
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        batched_state = jax.vmap(self.env.reset)(key)
        cache = batched_state._internal
        cache.first_env_state = batched_state.env_state
        cache.first_obs = batched_state.obs

        return batched_state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        state = jax.vmap(self.env.step)(state, action)
        done = state.done

        def _first_if_done(first_leaf, current_leaf):
            # Reshape the per-env done flag to (B, 1, ..., 1) so it
            # broadcasts against leaves with trailing dimensions.
            mask = done
            if mask.ndim > 0:
                mask = jnp.reshape(
                    mask, [first_leaf.shape[0]] + [1] * (len(first_leaf.shape) - 1)
                )
            return jnp.where(mask, first_leaf, current_leaf)

        cache = state._internal
        return state.replace(
            env_state=jtu.tree_map(
                _first_if_done, cache.first_env_state, state.env_state
            ),
            obs=jtu.tree_map(_first_if_done, cache.first_obs, state.obs),
        )
294
295
[docs]
class VmapEnvPoolAutoResetWrapper(Wrapper):
    """EnvPool style AutoReset.

    When the episode ends, an additional reset step is performed: the first
    ``step()`` after ``done`` ignores the action and returns a freshly reset
    state for that sub-env.
    See EnvPool: https://envpool.readthedocs.io/en/latest/content/python_interface.html#auto-reset, and the Next-Step Mode in gymnasium: https://farama.org/Vector-Autoreset-Mode.
    This is helpful for algorithms that require n-step TD or GAE with Partial episode bootstrapping (PEB) support on time-limited environments. When using this wrapper, remember to skip the invalid transitions via the mask ``state.info.autoreset``.
    """

    def __init__(self, env: Env, num_envs: int = 1):
        """Initialize the env wrapper.

        Args:
            env: the original (un-vectorized) env
            num_envs: number of parallel envs.
        """
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the vmapped env.

        Args:
            key: support batched keys [B,2] or single key [2]
        """
        if key.ndim <= 1:
            key = jax.random.split(key, self.num_envs)
        else:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )

        reset_key, key = rng_split(key)
        state = jax.vmap(self.env.reset)(key)
        state.info.autoreset = jnp.zeros_like(state.done)  # for autoreset flag
        state._internal.reset_key = reset_key  # for autoreset

        return state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step all sub-envs; sub-envs that were already done are reset instead.

        ``state.info.autoreset`` of the result marks the sub-envs whose
        transition is a reset step (and should be skipped by the learner).
        """
        autoreset = state.done  # i.e. prev_done

        def _where_autoreset(x, y):
            # Select x (reset) where prev_done, else y (stepped).
            # Start from the raw flag so a rank-0 done also works; previously
            # `cond` was only bound inside the `if`, raising UnboundLocalError.
            cond = autoreset
            if cond.ndim > 0:
                # (B,) -> (B, 1, ..., 1) to broadcast over trailing dims.
                cond = jnp.reshape(cond, [x.shape[0]] + [1] * (len(x.shape) - 1))
            return jnp.where(cond, x, y)

        # Both candidates are computed for every sub-env; the mask picks one.
        reset_state = self.reset(state._internal.reset_key)
        new_state = jax.vmap(self.env.step)(state, action)

        state = jtu.tree_map(
            _where_autoreset,
            reset_state,
            new_state,
        )
        state.info.autoreset = autoreset

        return state
356
357
[docs]
class AutoresetMode(Enum):
    """Autoreset strategy for vectorized envs.

    NOTE(review): the members presumably select the wrapper classes of the
    same name in this module (NORMAL -> VmapAutoResetWrapper,
    FAST -> FastVmapAutoResetWrapper, DISABLED -> VmapWrapper,
    ENVPOOL -> VmapEnvPoolAutoResetWrapper); confirm against the caller.
    """

    NORMAL = "normal"  # reset finished envs with a fresh random reset
    FAST = "fast"  # restore finished envs from the cached initial state
    DISABLED = "disabled"  # no autoreset
    ENVPOOL = "envpool"  # EnvPool-style extra reset step after done