Source code for evorl.envs.wrappers.ma_training_wrapper

  1import chex
  2import jax
  3import jax.tree_util as jtu
  4from jax import numpy as jnp
  5
  6from evorl.utils.jax_utils import rng_split
  7
  8from ..env import Env, EnvState
  9from .wrapper import Wrapper
 10
 11
class EpisodeWrapper(Wrapper):
    """Multi-Agent version of the [EpisodeWrapper](#training_wrapper.EpisodeWrapper).

    Tracks a per-episode step counter in ``state.info.steps`` and flags the
    episode as done for every agent once ``episode_length`` steps have elapsed,
    in addition to the termination reported by the wrapped env through
    ``state.done["__all__"]``.

    Args:
        env: The environment to wrap.
        episode_length: Number of steps after which an episode is truncated.
    """

    def __init__(self, env: Env, episode_length: int):
        super().__init__(env)
        self.episode_length = episode_length

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and initialise the bookkeeping fields in ``state.info``.

        Args:
            key: PRNG key forwarded to the wrapped env's ``reset``.

        Returns:
            The initial state with ``steps``, ``termination``, ``truncation``
            zeroed and ``ori_obs`` set to zeros shaped like ``state.obs``.
        """
        state = self.env.reset(key)

        state.info.steps = jnp.zeros((), dtype=jnp.int32)
        state.info.termination = jnp.zeros(())
        state.info.truncation = jnp.zeros(())
        state.info.ori_obs = jtu.tree_map(lambda x: jnp.zeros_like(x), state.obs)

        return state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Advance the env by one step; see ``_step`` for the bookkeeping."""
        return self._step(state, action)

    def _step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step the wrapped env and apply the episode-length limit.

        Args:
            state: The state returned by the previous ``reset``/``step`` call.
            action: The action forwarded to the wrapped env's ``step``.

        Returns:
            The new state where ``info.steps``, ``info.termination``,
            ``info.truncation`` and ``info.ori_obs`` are updated and every
            entry of ``done`` is forced to one when the episode ended by
            termination or truncation.
        """
        # Restart the counter when the incoming state already closed an
        # episode (by termination or truncation). Using the previous done flag
        # keeps the true length on the final transition and also resets the
        # counter after a truncation.
        prev_done = state.done["__all__"]
        prev_steps = state.info.steps * (1 - prev_done).astype(jnp.int32)

        state = self.env.step(state, action)

        termination = state.done["__all__"]
        steps = prev_steps + 1
        time_up = steps >= self.episode_length
        done = jnp.where(time_up, jnp.ones_like(termination), termination)

        # When the episode is over, all agents are done; otherwise keep the
        # per-agent flags reported by the wrapped env.
        agents_done = jtu.tree_map(
            lambda x: jnp.where(done, jnp.ones_like(x), x), state.done
        )

        state.info.steps = steps
        state.info.termination = termination
        # Truncation only counts when the env did not terminate on its own.
        state.info.truncation = jnp.where(
            time_up, 1 - termination, jnp.zeros_like(termination)
        )
        # the real next_obs at the end of episodes, where
        # state.obs could be changed in VmapAutoResetWrapper
        state.info.ori_obs = state.obs

        return state.replace(done=agents_done)
55 56
class OneEpisodeWrapper(EpisodeWrapper):
    """Multi-Agent version of the [OneEpisodeWrapper](#training_wrapper.OneEpisodeWrapper).

    Behaves like ``EpisodeWrapper`` until the incoming state reports
    ``done["__all__"]``; from then on ``step`` returns the incoming state
    unchanged.
    """

    def __init__(self, env: Env, episode_length: int):
        super().__init__(env, episode_length)

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step once, keeping the old state if the episode was already done.

        Args:
            state: The state returned by the previous ``reset``/``step`` call.
            action: The action forwarded to the underlying ``_step``.

        Returns:
            A state whose every leaf is taken from ``state`` when
            ``state.done["__all__"]`` is set, and from the freshly stepped
            state otherwise.
        """
        stepped = self._step(state, action)
        already_done = state.done["__all__"]

        def keep_if_done(prev_leaf, next_leaf):
            # Select leaf-wise so the whole state pytree is frozen together.
            return jnp.where(already_done, prev_leaf, next_leaf)

        return jtu.tree_map(keep_if_done, state, stepped)
70 71
class VmapWrapper(Wrapper):
    """Multi-Agent version of the [VmapWrapper](#training_wrapper.VmapWrapper).

    Runs ``num_envs`` copies of the wrapped env in a batch.

    Args:
        env: The environment to wrap.
        num_envs: Number of batched environments.
        vmap_step: If True, ``step`` is batched with ``jax.vmap``; otherwise
            it is mapped over the batch with ``jax.lax.map``.
    """

    def __init__(self, env: Env, num_envs: int = 1, vmap_step: bool = False):
        super().__init__(env)
        self.num_envs = num_envs
        self.vmap_step = vmap_step

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset all environments.

        Args:
            key: Either a single key, which is split into ``num_envs`` keys,
                or an already batched key of shape ``(num_envs, 2)``.

        Returns:
            The batched initial state.
        """
        if key.ndim > 1:
            # A batched key must provide exactly one key per environment.
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        batched_reset = jax.vmap(self.env.reset)
        return batched_reset(key)

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step all environments, via ``jax.vmap`` or ``jax.lax.map``."""
        if not self.vmap_step:
            return jax.lax.map(lambda pair: self.env.step(*pair), (state, action))

        batched_step = jax.vmap(self.env.step)
        return batched_step(state, action)
97 98
class VmapAutoResetWrapper(Wrapper):
    """Multi-Agent version of the [VmapAutoResetWrapper](#training_wrapper.VmapAutoResetWrapper).

    Batches the wrapped env with ``jax.vmap`` and, after every step, resets
    each environment whose ``done["__all__"]`` flag is set, using a key kept
    in ``state.info.reset_key``.

    Args:
        env: The environment to wrap.
        num_envs: Number of batched environments.
    """

    def __init__(self, env: Env, num_envs: int = 1):
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset all environments and store the keys used for later auto-resets.

        Args:
            key: Either a single key, which is split into ``num_envs`` keys,
                or an already batched key of shape ``(num_envs, 2)``.

        Returns:
            The batched initial state with ``info.reset_key`` populated.
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        reset_key, key = rng_split(key)
        batched_reset = jax.vmap(self.env.reset)
        state = batched_reset(key)
        state.info.reset_key = reset_key  # for autoreset

        return state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step all environments, then reset the ones that are done."""
        stepped = jax.vmap(self.env.step)(state, action)

        # Map heterogeneous computation (non-parallelizable).
        # This avoids lax.cond becoming lax.select in vmap
        return jax.lax.map(self._maybe_reset, stepped)

    def _auto_reset(self, state: EnvState) -> EnvState:
        """Replace ``env_state`` and ``obs`` with those of a fresh reset."""
        # Split so that the key stored for the next reset differs from the
        # one consumed by this reset.
        next_key, episode_key = jax.random.split(state.info.reset_key)
        fresh = self.env.reset(episode_key)

        state = state.replace(env_state=fresh.env_state, obs=fresh.obs)
        state.info.reset_key = next_key

        return state

    def _maybe_reset(self, state: EnvState) -> EnvState:
        """Apply ``_auto_reset`` only when ``state.done["__all__"]`` is set."""

        def _keep(state):
            return state

        return jax.lax.cond(state.done["__all__"], self._auto_reset, _keep, state)
153 154
class FastVmapAutoResetWrapper(Wrapper):
    """Multi-Agent version of the [FastVmapAutoResetWrapper](#training_wrapper.FastVmapAutoResetWrapper).

    Batches the wrapped env with ``jax.vmap``. The initial ``env_state`` and
    ``obs`` are cached in ``state.info`` at reset time and restored for every
    environment whose ``done["__all__"]`` flag is set after a step.

    Args:
        env: The environment to wrap.
        num_envs: Number of batched environments.
    """

    def __init__(self, env: Env, num_envs: int = 1):
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset all environments and cache their initial state and observation.

        Args:
            key: Either a single key, which is split into ``num_envs`` keys,
                or an already batched key of shape ``(num_envs, 2)``.

        Returns:
            The batched initial state with ``info.first_env_state`` and
            ``info.first_obs`` populated.
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        state = jax.vmap(self.env.reset)(key)
        state.info.first_obs = state.obs
        state.info.first_env_state = state.env_state

        return state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step all environments and restore the cached start for done ones."""
        state = jax.vmap(self.env.step)(state, action)
        all_done = state.done["__all__"]

        def where_done(cached, current):
            mask = all_done
            if mask.ndim > 0:
                # Add trailing singleton axes so the mask broadcasts over the leaf.
                mask = jnp.reshape(
                    mask, (cached.shape[0],) + (1,) * (cached.ndim - 1)
                )
            return jnp.where(mask, cached, current)

        restored_obs = jtu.tree_map(where_done, state.info.first_obs, state.obs)
        restored_env_state = jtu.tree_map(
            where_done, state.info.first_env_state, state.env_state
        )

        return state.replace(env_state=restored_env_state, obs=restored_obs)