1import chex
2import jax
3import jax.tree_util as jtu
4from jax import numpy as jnp
5
6from evorl.utils.jax_utils import rng_split
7
8from ..env import Env, EnvState
9from .wrapper import Wrapper
10
11
[docs]
class EpisodeWrapper(Wrapper):
    """Multi-Agent version of the [EpisodeWrapper](#training_wrapper.EpisodeWrapper).

    Counts steps within the current episode and ends the episode once
    ``episode_length`` steps have been taken. After every step it records in
    ``state.info``:

    - ``steps``: number of steps taken in the current episode (1 after the
      first step of an episode).
    - ``termination``: the wrapped env's own ``done["__all__"]`` flag.
    - ``truncation``: ``1 - termination`` when ``steps >= episode_length``,
      otherwise 0.
    - ``ori_obs``: the observation returned by the wrapped env for this step.

    When the episode is over (termination or truncation), every entry of
    ``state.done`` (per-agent flags and ``"__all__"``) is set to one;
    otherwise the wrapped env's flags are passed through unchanged.

    Args:
        env: The multi-agent environment to wrap.
        episode_length: Maximum number of steps per episode.
    """

    def __init__(self, env: Env, episode_length: int):
        super().__init__(env)
        self.episode_length = episode_length

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and zero the episode bookkeeping in ``state.info``."""
        state = self.env.reset(key)

        state.info.steps = jnp.zeros((), dtype=jnp.int32)
        state.info.termination = jnp.zeros(())
        state.info.truncation = jnp.zeros(())
        state.info.ori_obs = jtu.tree_map(jnp.zeros_like, state.obs)

        return state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Advance the wrapped env by one step and update episode bookkeeping."""
        return self._step(state, action)

    def _step(self, state: EnvState, action: jax.Array) -> EnvState:
        # The incoming ``done["__all__"]`` is the flag this wrapper produced on
        # the previous step (termination OR truncation), or the wrapped env's
        # reset value on the first step. Restart the counter after it so that a
        # truncated episode is also followed by a fresh count. Resetting on the
        # *current* termination instead let ``steps`` grow past
        # ``episode_length`` after a truncation, truncating every later step.
        prev_done = state.done["__all__"]
        steps = (
            jnp.where(prev_done, jnp.zeros_like(state.info.steps), state.info.steps)
            + 1
        )

        state = self.env.step(state, action)

        termination = state.done["__all__"]
        truncated = steps >= self.episode_length
        done = jnp.where(truncated, jnp.ones_like(termination), termination)

        # Once the episode is over mark every agent (and "__all__") as done;
        # otherwise keep the wrapped env's per-agent flags as they are.
        agents_done = jtu.tree_map(
            lambda x: jnp.where(done, jnp.ones_like(x), x), state.done
        )

        state.info.steps = steps
        state.info.termination = termination
        state.info.truncation = jnp.where(
            truncated, 1 - termination, jnp.zeros_like(termination)
        )
        # The real next_obs at the end of episodes, where state.obs could be
        # changed in VmapAutoResetWrapper.
        state.info.ori_obs = state.obs

        return state.replace(done=agents_done)
55
56
[docs]
class OneEpisodeWrapper(EpisodeWrapper):
    """Multi-Agent version of the [OneEpisodeWrapper](#training_wrapper.OneEpisodeWrapper).

    Runs at most one episode: once the incoming ``state.done["__all__"]`` is
    set, :meth:`step` returns that state unchanged instead of advancing it.

    Args:
        env: The multi-agent environment to wrap.
        episode_length: Maximum number of steps of the single episode.
    """

    def __init__(self, env: Env, episode_length: int):
        super().__init__(env, episode_length)

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        # Rebuild the pytree containers of the incoming state before stepping.
        # ``_step`` assigns into ``info`` of the state returned by the wrapped
        # env in place. If that env hands back the same ``info`` container it
        # received (e.g. via a shallow ``replace``), those writes would also
        # appear in ``state`` and leak into the "keep old" branch below.
        # The identity tree_map only creates new containers; the array leaves
        # are shared, so it adds no computation.
        frozen = jtu.tree_map(lambda x: x, state)
        finished = frozen.done["__all__"]

        new_state = self._step(state, action)
        return jtu.tree_map(
            lambda old, new: jnp.where(finished, old, new),
            frozen,
            new_state,
        )
70
71
[docs]
class VmapWrapper(Wrapper):
    """Multi-Agent version of the [VmapWrapper](#training_wrapper.VmapWrapper).

    Batches ``num_envs`` copies of the wrapped environment. ``reset`` is always
    vectorised with :func:`jax.vmap`. ``step`` is vectorised with
    :func:`jax.vmap` when ``vmap_step`` is true, and otherwise runs the
    environments one after another with :func:`jax.lax.map`.

    Args:
        env: The multi-agent environment to batch.
        num_envs: Number of environment copies.
        vmap_step: Use ``jax.vmap`` (True) or ``jax.lax.map`` (False) in ``step``.
    """

    def __init__(self, env: Env, num_envs: int = 1, vmap_step: bool = False):
        super().__init__(env)
        self.num_envs = num_envs
        self.vmap_step = vmap_step

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset all ``num_envs`` environments.

        ``key`` is either a single PRNG key (split into ``num_envs`` keys) or an
        already batched key whose shape must be ``(num_envs, 2)``.
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        batched_reset = jax.vmap(self.env.reset)
        return batched_reset(key)

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step every environment with its slice of ``state`` and ``action``."""
        if not self.vmap_step:
            return jax.lax.map(
                lambda pair: self.env.step(pair[0], pair[1]), (state, action)
            )
        return jax.vmap(self.env.step)(state, action)
97
98
[docs]
class VmapAutoResetWrapper(Wrapper):
    """Multi-Agent version of the [VmapAutoResetWrapper](#training_wrapper.VmapAutoResetWrapper).

    Steps ``num_envs`` environments with :func:`jax.vmap` and re-initialises
    the ``env_state`` and ``obs`` of every environment whose ``done["__all__"]``
    flag is set. Each environment keeps its own ``info.reset_key``, which is
    split on every automatic reset so successive resets draw different keys.

    Args:
        env: The multi-agent environment to batch.
        num_envs: Number of environment copies.
    """

    def __init__(self, env: Env, num_envs: int = 1):
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset all environments and store a per-env key for later auto-resets.

        ``key`` is either a single PRNG key (split into ``num_envs`` keys) or an
        already batched key whose shape must be ``(num_envs, 2)``.
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        reset_key, init_key = rng_split(key)
        state = jax.vmap(self.env.reset)(init_key)
        # Consumed (and refreshed) by _auto_reset.
        state.info.reset_key = reset_key
        return state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step all environments, then reset those whose episode has ended."""
        stepped = jax.vmap(self.env.step)(state, action)

        # Whether to reset differs per environment, so it is applied with
        # lax.map: under vmap a batched lax.cond is turned into lax.select,
        # which would evaluate both branches (including env.reset) for every env.
        return jax.lax.map(self._maybe_reset, stepped)

    def _auto_reset(self, state: EnvState) -> EnvState:
        # Split so that the stored key differs every time this env is reset.
        next_reset_key, this_reset_key = jax.random.split(state.info.reset_key)
        fresh = self.env.reset(this_reset_key)

        state = state.replace(env_state=fresh.env_state, obs=fresh.obs)
        state.info.reset_key = next_reset_key
        return state

    def _maybe_reset(self, state: EnvState) -> EnvState:
        def keep(unchanged):
            return unchanged

        return jax.lax.cond(state.done["__all__"], self._auto_reset, keep, state)
153
154
[docs]
class FastVmapAutoResetWrapper(Wrapper):
    """Multi-Agent version of the [FastVmapAutoResetWrapper](#training_wrapper.FastVmapAutoResetWrapper).

    Steps ``num_envs`` environments with :func:`jax.vmap`. A finished
    environment (``done["__all__"]`` set) is restored to the ``env_state`` and
    ``obs`` captured by the initial :meth:`reset` (kept in
    ``info.first_env_state`` / ``info.first_obs``) instead of calling the
    wrapped env's ``reset`` again. Every later episode therefore starts from
    that same initial state.

    Args:
        env: The multi-agent environment to batch.
        num_envs: Number of environment copies.
    """

    def __init__(self, env: Env, num_envs: int = 1):
        super().__init__(env)
        self.num_envs = num_envs

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset all environments and remember their initial state.

        ``key`` is either a single PRNG key (split into ``num_envs`` keys) or an
        already batched key whose shape must be ``(num_envs, 2)``.
        """
        if key.ndim > 1:
            chex.assert_shape(
                key,
                (self.num_envs, 2),
                custom_message=f"Batched key shape {key.shape} must match num_envs: {self.num_envs}",
            )
        else:
            key = jax.random.split(key, self.num_envs)

        state = jax.vmap(self.env.reset)(key)
        # Snapshot that step() copies back into finished environments.
        state.info.first_env_state = state.env_state
        state.info.first_obs = state.obs
        return state

    def step(self, state: EnvState, action: jax.Array) -> EnvState:
        """Step all environments, restoring finished ones to their initial state."""
        state = jax.vmap(self.env.step)(state, action)
        all_done = state.done["__all__"]

        def pick(first, current):
            mask = all_done
            if mask.ndim > 0:
                # Broadcast the per-env flag over the trailing axes of this leaf.
                mask = jnp.reshape(mask, [first.shape[0]] + [1] * (first.ndim - 1))
            return jnp.where(mask, first, current)

        restored_env_state = jtu.tree_map(
            pick, state.info.first_env_state, state.env_state
        )
        restored_obs = jtu.tree_map(pick, state.info.first_obs, state.obs)
        return state.replace(env_state=restored_env_state, obs=restored_obs)
192 return state.replace(env_state=env_state, obs=obs)