Source code for evorl.envs.wrappers.training_wrapper

  1from enum import Enum
  2
  3import chex
  4import jax
  5import jax.numpy as jnp
  6import jax.tree_util as jtu
  7from evorl.utils.jax_utils import rng_split
  8
  9from ..env import Env, EnvState
 10from .wrapper import Wrapper
 11
 12
class EpisodeWrapper(Wrapper):
    """Track the per-episode step count and flag `done` at the time limit.

    This is the same as brax's EpisodeWrapper, with some extra fields written
    into `state.info`:

    - steps: the current step count of the episode
    - truncation: whether the episode is truncated (time limit reached
      without the wrapped env reporting done)
    - termination: whether the wrapped env reported done
    - ori_obs: the next observation without autoreset
      (only when `record_ori_obs` is True)
    - episode_return: the current sum of discounted rewards of the episode
      (only when `discount` is not None)
    """

    def __init__(
        self,
        env: Env,
        episode_length: int,
        record_ori_obs: bool = True,
        discount: float | None = None,
    ):
        """Initialize the env wrapper.

        Args:
            env: the wrapped env; it should be a single un-vectorized environment.
            episode_length: the maximum length of each episode, used for truncation.
            record_ori_obs: whether to record the real next observation of each episode.
            discount: the discount factor for computing the return. The default
                `None` means the episode_return is not recorded.
        """
        super().__init__(env)
        self.episode_length = episode_length
        self.record_ori_obs = record_ori_obs
        self.discount = discount
        # The return is only tracked when a discount factor was supplied.
        self.record_episode_return = discount is not None

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and zero-initialize the episode bookkeeping.

        Args:
            key: random key forwarded to the wrapped env's `reset`.

        Returns:
            The reset state whose `info` additionally holds `steps`,
            `termination`, `truncation` and, when enabled, `ori_obs` and
            `episode_return`.
        """
        init_state = self.env.reset(key)

        episode_info = init_state.info.replace(
            steps=jnp.zeros((), dtype=jnp.int32),
            termination=jnp.zeros(()),
            truncation=jnp.zeros(()),
        )

        if self.record_ori_obs:
            # Placeholder with the same structure as obs; filled in by step.
            episode_info.ori_obs = jtu.tree_map(jnp.zeros_like, init_state.obs)

        if self.record_episode_return:
            episode_info.episode_return = jnp.zeros(())

        return init_state.replace(info=episode_info)

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Advance one step; see `_step` for the bookkeeping details."""
        return self._step(state, action)

    def _step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step the wrapped env and update the episode bookkeeping.

        Args:
            state: the state returned by the previous `reset` or `step`.
            action: the action forwarded to the wrapped env.

        Returns:
            The new state, with `done` also set once `steps` reaches
            `episode_length`, and with the `info` fields refreshed.
        """
        # ============== before the env update ==============
        was_done = state.done
        # Counters restart from zero when the previous step ended an episode
        # (either by truncation or termination).
        step_count = state.info.steps * (1 - was_done).astype(jnp.int32)

        if self.record_episode_return:
            running_return = state.info.episode_return * (1 - was_done)

        next_state = self.env.step(state, action)

        # ============== after the env update ==============
        step_count = step_count + 1
        terminated = next_state.done
        time_limit_hit = step_count >= self.episode_length

        done = jnp.where(time_limit_hit, jnp.ones_like(terminated), terminated)

        # Note: when termination and truncation both happen at the last
        # step, truncation is set to 0.
        truncated = jnp.where(
            time_limit_hit, 1 - terminated, jnp.zeros_like(terminated)
        )

        episode_info = next_state.info.replace(
            steps=step_count,
            termination=terminated,
            truncation=truncated,
        )

        if self.record_ori_obs:
            # The real next obs (i.e. obs at t+1) at the end of an episode;
            # state.obs itself could later be replaced by the next episode's
            # initial observation by VmapAutoResetWrapper.
            episode_info.ori_obs = next_state.obs

        if self.record_episode_return:
            if self.discount == 1.0:
                # Shortcut for discount=1.0: no power computation needed.
                reward_term = next_state.reward
            else:
                reward_term = (
                    jnp.power(self.discount, step_count - 1) * next_state.reward
                )
            running_return = running_return + reward_term
            episode_info.episode_return = running_return

        return next_state.replace(done=done, info=episode_info)
110 111
class OneEpisodeWrapper(EpisodeWrapper):
    """Episode wrapper that freezes the state once the episode is done.

    Maintains the episode step count and sets done at episode end. Calling
    `step()` after the env is done returns the previous state unchanged.
    """

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step the env, keeping the old state if it was already done.

        Args:
            state: the state returned by the previous `reset` or `step`.
            action: the action forwarded to the wrapped env.

        Returns:
            A state whose leaves come from `state` where `state.done` is
            set and from the freshly stepped state otherwise.
        """
        already_done = state.done
        stepped = self._step(state, action)

        def _keep_old_if_done(old_leaf, new_leaf):
            return jnp.where(already_done, old_leaf, new_leaf)

        # The selection is done leaf-wise with jnp.where rather than with
        # lax.cond: under jax.vmap lax.cond traces both branches, which
        # breaks warp backend's custom_vmap for mjx.step.
        return jtu.tree_map(_keep_old_if_done, state, stepped)
129 130
class VmapWrapper(Wrapper):
    """Vectorize an env over a leading batch axis (no autoreset)."""

    def __init__(self, env: Env, num_envs: int = 1, vmap_step: bool = False):
        """Initialize the env wrapper.

        Args:
            env: the original env
            num_envs: number of envs to vectorize
            vmap_step: if True, vectorize the step function with `vmap`;
                otherwise run it over the batch with `lax.map`
        """
        super().__init__(env)
        self.num_envs = num_envs
        self.vmap_step = vmap_step

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the vmapped env.

        Args:
            key: support batched keys [B,2] or single key [2]

        Returns:
            The batched state produced by vmapping the wrapped env's `reset`.
        """
        if key.ndim > 1:
            # Already batched: it must provide exactly one key per env.
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
            batched_key = key
        else:
            # Single key: derive one key per env.
            batched_key = jax.random.split(key, self.num_envs)

        return jax.vmap(self.env.reset)(batched_key)

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step every env in the batch.

        Args:
            state: the batched state returned by `reset` or `step`.
            action: the batched actions, one per env.

        Returns:
            The batched state after stepping each env.
        """
        if self.vmap_step:
            return jax.vmap(self.env.step)(state, action)

        def _step_single(state_and_action):
            single_state, single_action = state_and_action
            return self.env.step(single_state, single_action)

        return jax.lax.map(_step_single, (state, action))
168 169
class VmapAutoResetWrapper(Wrapper):
    """Vectorize an env and automatically reset sub-envs that are done."""

    def __init__(self, env: Env, num_envs: int = 1):
        """Initialize the env wrapper.

        Args:
            env: the original env
            num_envs: number of parallel envs.
        """
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the vmapped env.

        Args:
            key: support batched keys [B,2] or single key [2]

        Returns:
            The batched reset state, with the keys reserved for later
            autoresets stored in `state._internal.reset_key`.
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        # NOTE(review): rng_split is a project helper applied to the batched
        # keys here; presumably it splits each key of the batch - confirm.
        autoreset_key, env_key = rng_split(key)

        batched_state = jax.vmap(self.env.reset)(env_key)
        batched_state._internal.reset_key = autoreset_key  # for autoreset

        return batched_state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step every env, then reset the ones whose episode has ended.

        Args:
            state: the batched state returned by `reset` or `step`.
            action: the batched actions, one per env.

        Returns:
            The batched state after the step and any autoresets.
        """
        stepped = jax.vmap(self.env.step)(state, action)

        # Map heterogeneous computation (non-parallelizable).
        # This avoids lax.cond becoming lax.select in vmap.
        return jax.lax.map(self._maybe_reset, stepped)

    def _auto_reset(self, state: EnvState) -> EnvState:
        """Reset the state of one env.

        Overwrites `env_state` and `obs` with those of a freshly reset env
        and stores a new `reset_key`; the other fields of `state` are kept.
        """
        # Split so that the key used by the environment changes at each reset.
        next_reset_key, episode_key = jax.random.split(state._internal.reset_key)
        fresh = self.env.reset(episode_key)

        return state.replace(
            env_state=fresh.env_state,
            obs=fresh.obs,
            _internal=state._internal.replace(reset_key=next_reset_key),
        )

    def _maybe_reset(self, state: EnvState) -> EnvState:
        """Reset one env's state if its episode is done, else pass it through."""

        def _passthrough(current: EnvState) -> EnvState:
            return current.replace()

        return jax.lax.cond(
            state.done,
            self._auto_reset,
            _passthrough,
            state,
        )
238 239
class FastVmapAutoResetWrapper(Wrapper):
    """Brax-style AutoReset: no randomness in reset.

    This wrapper reuses the state returned by the initial `env.reset()`
    instead of resetting again. When the episodes are short or `env.reset()`
    is expensive, this wrapper is more efficient than `VmapAutoResetWrapper`.
    """

    def __init__(self, env: Env, num_envs: int = 1):
        """Initialize the env wrapper.

        Args:
            env: the original env
            num_envs: number of parallel envs.
        """
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the vmapped env.

        Args:
            key: support batched keys [B,2] or single key [2]

        Returns:
            The batched reset state; its `env_state` and `obs` are also
            stored in `state._internal` as `first_env_state` / `first_obs`
            so that `step` can restore them.
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        initial = jax.vmap(self.env.reset)(key)
        # Remember the initial state; step() restores it for done envs.
        initial._internal.first_env_state = initial.env_state
        initial._internal.first_obs = initial.obs

        return initial

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step every env and restore the stored initial state where done.

        Args:
            state: the batched state returned by `reset` or `step`.
            action: the batched actions, one per env.

        Returns:
            The stepped state, with `env_state` and `obs` replaced by the
            stored first ones for every env whose `done` is set.
        """
        stepped = jax.vmap(self.env.step)(state, action)
        done = stepped.done

        def _pick_first_if_done(first_leaf, current_leaf):
            mask = done
            if mask.ndim > 0:
                # Append singleton axes so the mask broadcasts over the leaf.
                mask = jnp.reshape(
                    mask, [first_leaf.shape[0]] + [1] * (len(first_leaf.shape) - 1)
                )
            return jnp.where(mask, first_leaf, current_leaf)

        restored_env_state = jtu.tree_map(
            _pick_first_if_done,
            stepped._internal.first_env_state,
            stepped.env_state,
        )
        restored_obs = jtu.tree_map(
            _pick_first_if_done,
            stepped._internal.first_obs,
            stepped.obs,
        )

        return stepped.replace(env_state=restored_env_state, obs=restored_obs)
294 295
class VmapEnvPoolAutoResetWrapper(Wrapper):
    """EnvPool style AutoReset.

    When the episode ends, an additional reset step is performed.
    See EnvPool: https://envpool.readthedocs.io/en/latest/content/python_interface.html#auto-reset,
    and the Next-Step Mode in gymnasium: https://farama.org/Vector-Autoreset-Mode.
    This is helpful for algorithms that require n-step TD or GAE with Partial
    episode bootstrapping (PEB) support on time-limited environments. When
    using this wrapper, remember to skip the invalid transitions via the mask
    `autoreset` (stored in `state.info.autoreset`).
    """

    def __init__(self, env: Env, num_envs: int = 1):
        """Initialize the env wrapper.

        Args:
            env: the original env
            num_envs: number of parallel envs.
        """
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the vmapped env.

        Args:
            key: support batched keys [B,2] or single key [2]

        Returns:
            The batched reset state, with `info.autoreset` zero-initialized
            and the keys reserved for later autoresets stored in
            `state._internal.reset_key`.
        """
        if key.ndim <= 1:
            key = jax.random.split(key, self.num_envs)
        else:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )

        # NOTE(review): rng_split is a project helper applied to the batched
        # keys here; presumably it splits each key of the batch - confirm.
        reset_key, key = rng_split(key)
        state = jax.vmap(self.env.reset)(key)
        state.info.autoreset = jnp.zeros_like(state.done)  # for autoreset flag
        state._internal.reset_key = reset_key  # for autoreset

        return state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step every env; envs that were done on the previous step reset instead.

        Args:
            state: the batched state returned by `reset` or `step`.
            action: the batched actions, one per env.

        Returns:
            A state whose leaves come from a fresh reset where the previous
            `done` was set and from the regular step otherwise. The previous
            `done` is stored in `info.autoreset` to mark those reset steps.
        """
        autoreset = state.done  # i.e. prev_done

        def _where_autoreset(x, y):
            # Select x (reset leaf) where prev_done, else y (stepped leaf).
            # `cond` is bound before the branch so that a 0-dim `done` falls
            # back to plain broadcasting instead of hitting an unbound name.
            cond = autoreset
            if cond.ndim > 0:
                # Append singleton axes so the mask broadcasts over the leaf.
                cond = jnp.reshape(cond, [x.shape[0]] + [1] * (len(x.shape) - 1))
            return jnp.where(cond, x, y)

        # Both the reset and the step are computed for every env on each
        # call; the mask then picks which result each env keeps.
        reset_state = self.reset(state._internal.reset_key)
        new_state = jax.vmap(self.env.step)(state, action)

        state = jtu.tree_map(
            _where_autoreset,
            reset_state,
            new_state,
        )
        state.info.autoreset = autoreset

        return state
356 357
class AutoresetMode(Enum):
    """Names of the available autoreset modes.

    Each member's value is its lower-case string form, so a mode can be
    looked up from a string, e.g. `AutoresetMode("fast")`.
    """

    # NOTE(review): presumably each mode selects one of the vmap wrappers
    # defined in this module; the mapping is not defined here - confirm
    # against the callers.
    NORMAL = "normal"
    FAST = "fast"
    DISABLED = "disabled"
    ENVPOOL = "envpool"