1from functools import partial
2import warnings
3import numpy as np
4import math
5import gymnasium
6import multiprocessing as mp
7
8import chex
9import jax
10import jax.numpy as jnp
11import jax.tree_util as jtu
12
13from evorl.types import Action, PyTreeDict
14
15from .env import Env, EnvAdapter, EnvState
16from .space import Box, Discrete, Space
17from .wrappers import Wrapper, AutoresetMode
18
19
def _to_jax(pytree):
    """Return a copy of ``pytree`` whose leaves are passed through ``jnp.asarray``.

    The container structure (dicts, tuples, lists, ...) is preserved; only the
    leaves are converted.
    """
    leaves, treedef = jtu.tree_flatten(pytree)
    converted = [jnp.asarray(leaf) for leaf in leaves]
    return jtu.tree_unflatten(treedef, converted)
22
23
def _to_jax_spec(pytree):
    """Describe every leaf of ``pytree`` as a ``jax.ShapeDtypeStruct``.

    Leaves are first converted with ``_to_jax`` so that the recorded shape and
    dtype are the ones JAX assigns to the value, then replaced by their
    shape/dtype description.
    """

    def _leaf_spec(leaf):
        return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)

    jax_tree = _to_jax(pytree)
    return jtu.tree_map(_leaf_spec, jax_tree)
27
28
def _reshape_batch_dims(pytree, batch_shape):
    """Replace the leading axis of every leaf with ``batch_shape``.

    Maps ``[B1*...*Bn*#envs, *] -> [B1, ..., Bn, #envs, *]``.

    Args:
        pytree: Pytree whose leaves share the same flattened leading axis.
        batch_shape: Tuple of dimensions that the leading axis is expanded to.

    Returns:
        A pytree with the same structure whose leaves come from ``jnp.reshape``.
    """

    def _expand_leading_axis(leaf):
        # Keep every trailing (per-env) dimension untouched.
        target_shape = batch_shape + leaf.shape[1:]
        return jnp.reshape(leaf, target_shape)

    return jtu.tree_map(_expand_leading_axis, pytree)
32
33
[docs]
class GymnasiumAdapter(EnvAdapter):
    """Adapter for Gymnasium to support Gymnasium environments.

    This env is already a vectorized environment and has experimental support. It is not recommended to directly replace other jax-based envs with this env in EvoRL's existing workflows. Users should carefully check the compatibility and modify the corresponding code to avoid side effects and other undefined behaviors.

    :::{caution}
    This env breaks the rule of pure functions. Its env state is maintained inside gymnasium. Therefore, users should use it with caution. Unlike other jax-based envs, this env has the following limitations:

    - No support for recovering from a previous env state.
        - In other words, you can't rewind after calling `env.step`.
        - For example, you can't resume the training from a checkpoint exactly as before; similarly, `evorl.rollout.eval_rollout_episode` will also result in undefined behavior.
    - We use gymnasium's `AsyncVectorEnv`, which uses python's `multiprocessing` package for parallelism. This may cause performance issues, especially when the number of parallel environments is large.
        - We recommend that the total number of parallel environments does not exceed the number of CPU logical cores.
    :::
    """

    # TODO: multi-device support

    def __init__(
        self,
        env_name: str,
        max_episode_steps: int,
        num_envs: int,
        record_ori_obs: bool = False,
        discount: float | None = None,
        vecenv_kwargs: dict | None = None,
        **env_kwargs,
    ):
        """Create the vectorized Gymnasium env and the host callbacks.

        Args:
            env_name: Environment id forwarded to `gymnasium.make_vec`.
            max_episode_steps: Forwarded to `gymnasium.make_vec`.
            num_envs: Number of parallel envs per (unbatched) adapter call.
            record_ori_obs: Whether to expose `info.ori_obs`.
            discount: Discount used for `info.episode_return`. `None`
                disables recording the episode return.
            vecenv_kwargs: Forwarded to `gymnasium.make_vec` as `vector_kwargs`.
            **env_kwargs: Extra keyword arguments forwarded to `gymnasium.make_vec`.
        """
        self.env_name = env_name
        self.max_episode_steps = max_episode_steps
        self.num_envs = num_envs
        self.record_ori_obs = record_ori_obs
        self.record_episode_return = discount is not None
        self.discount = discount
        self.vecenv_kwargs = vecenv_kwargs
        self.env_specs = env_kwargs

        def _env_fn(num_envs):
            env = gymnasium.make_vec(
                self.env_name,
                num_envs=num_envs,
                max_episode_steps=self.max_episode_steps,
                vectorization_mode=gymnasium.VectorizeMode.ASYNC,
                vector_kwargs=self.vecenv_kwargs,
                **self.env_specs,
            )

            return env

        self._env_fn = _env_fn
        self.env = _env_fn(num_envs)
        self.autoreset_mode = self.env.metadata["autoreset_mode"]

        self.setup_env_callback()

    def setup_env_callback(self):
        """Build `self._reset` and `self._step` as `jax.pure_callback` wrappers.

        The current `self.env` is reset and stepped once with a sampled action
        to derive the shape/dtype specs of the callback outputs.
        """
        dummy_obs, _ = self.env.reset()
        # define your own dummy reset info here
        dummy_reset_info = PyTreeDict()
        if self.record_ori_obs:
            dummy_reset_info.ori_obs = dummy_obs
        reset_spec = _to_jax_spec((dummy_obs, dummy_reset_info))

        dummy_action = self.env.single_action_space.sample()
        dummy_actions = np.broadcast_to(
            dummy_action, (self.num_envs,) + dummy_action.shape
        )
        # define your own dummy step info here
        dummy_step_info = PyTreeDict()
        if self.record_ori_obs:
            dummy_step_info.ori_obs = dummy_obs

        step_spec = _to_jax_spec(self.env.step(dummy_actions)[:-1] + (dummy_step_info,))

        # Rank of a single action, resolved once here rather than rebuilding
        # the action space inside the host callback on every step.
        action_ndim = len(self.action_space.shape)

        def _reset(key):
            # The key is only used for its shape: leading dims are batch dims.
            batch_shape = key.shape[:-1]
            num_envs = math.prod(batch_shape) * self.num_envs

            # TODO: reuse the multiprocessing workers from prev self.env,
            # to avoid creating new processes.
            prev_env = self.env
            self.env = self._env_fn(num_envs)
            # Release the previous env explicitly instead of relying on
            # garbage collection to clean it up.
            prev_env.close()

            assert self.env.num_envs == num_envs

            # Only the observation is reshaped: the original info dict is
            # dropped as it does not have a static shape.
            raw_obs, _info = self.env.reset()
            obs = _reshape_batch_dims(raw_obs, batch_shape + (self.num_envs,))

            info = PyTreeDict()
            if self.record_ori_obs:
                info.ori_obs = jnp.zeros_like(obs)

            return obs, info

        def _step(actions):
            # Note: we are not sure if self.env is always updated by _reset in JIT mode.

            # [B1, ..., Bn, #envs]
            # Slice by an explicit end index: `[:-action_ndim]` would yield an
            # empty shape when action_ndim == 0 (scalar actions).
            batch_shape = actions.shape[: len(actions.shape) - action_ndim]

            # [B1, ..., Bn, #envs, *] -> [B1*...*Bn*#envs, *]
            actions = jax.lax.collapse(actions, 0, len(batch_shape))

            # [B1*...*Bn*#envs, *]
            obs, reward, termination, truncation, _info = self.env.step(
                np.asarray(actions)
            )

            # drop the original info dict as they do not have static shape.
            info = PyTreeDict()
            if self.record_ori_obs:
                ori_obs = obs.copy()
                final_obs_list = _info.get("final_obs", None)
                if final_obs_list is not None:
                    valid_indices = np.array(
                        [o is not None for o in final_obs_list]
                    ).nonzero()[0]
                    # np.stack raises on an empty list, so skip when no env
                    # reported a final observation.
                    if valid_indices.size > 0:
                        ori_obs[valid_indices] = np.stack(
                            [final_obs_list[i] for i in valid_indices]
                        )

                info.ori_obs = ori_obs

            return _reshape_batch_dims(
                (obs, reward, termination, truncation, info), batch_shape
            )

        # You are entering the dangerous zone!!!
        # _reset and _step are not pure functions. Use with caution.
        self._reset = partial(
            jax.pure_callback, _reset, reset_spec, vmap_method="expand_dims"
        )
        self._step = partial(
            jax.pure_callback, _step, step_spec, vmap_method="expand_dims"
        )

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the envs and return the initial `EnvState`.

        Args:
            key: PRNG key. The host callback only reads its shape.

        Returns:
            An `EnvState` with `env_state=None`, zero reward/done, and `info`
            holding `steps`, `termination`, `truncation`, plus `autoreset`
            (Next-Step mode) and `episode_return` (when a discount is set).
        """
        obs, info = _to_jax(self._reset(key))

        info.steps = jnp.zeros((self.num_envs,), dtype=jnp.int32)
        info.termination = jnp.zeros((self.num_envs,))
        info.truncation = jnp.zeros((self.num_envs,))

        if self.autoreset_mode == gymnasium.vector.AutoresetMode.NEXT_STEP:
            info.autoreset = jnp.zeros((self.num_envs,))
        if self.record_episode_return:
            info.episode_return = jnp.zeros((self.num_envs,))

        return EnvState(
            env_state=None,
            obs=obs,
            reward=jnp.zeros((self.num_envs,)),
            done=jnp.zeros((self.num_envs,)),
            info=info,
        )

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the envs, dispatching on the Gymnasium autoreset mode.

        Raises:
            NotImplementedError: If the autoreset mode is neither NEXT_STEP
                nor SAME_STEP.
        """
        if self.autoreset_mode == gymnasium.vector.AutoresetMode.NEXT_STEP:
            return self._envpool_autoreset_step(state, action)
        elif self.autoreset_mode == gymnasium.vector.AutoresetMode.SAME_STEP:
            return self._normal_autoreset_step(state, action)
        else:
            raise NotImplementedError(
                f"Unsupported autoreset mode: {self.autoreset_mode}"
            )

    def _envpool_autoreset_step(self, state: EnvState, action: Action) -> EnvState:
        """Step for Next-Step mode."""
        autoreset = state.done  # True = this step is the reset() step

        obs, reward, termination, truncation, info = _to_jax(self._step(action))

        reward = reward.astype(jnp.float32)
        done = jnp.logical_or(termination, truncation).astype(jnp.float32)

        # The step counter restarts from 0 on the reset step.
        info.steps = (state.info.steps + 1) * (1 - autoreset).astype(jnp.int32)
        info.termination = termination.astype(jnp.float32)
        info.truncation = truncation.astype(jnp.float32)
        info.autoreset = autoreset  # prev_done

        if self.record_episode_return:
            episode_return = state.info.episode_return
            if self.discount == 1.0:
                episode_return = episode_return + reward
            else:
                # The pre-step counter is the exponent of the discount.
                episode_return = (
                    episode_return + jnp.power(self.discount, state.info.steps) * reward
                )
            # The accumulated return is cleared on the reset step.
            info.episode_return = episode_return * (1 - autoreset)

        return state.replace(obs=obs, reward=reward, done=done, info=info)

    def _normal_autoreset_step(self, state: EnvState, action: Action) -> EnvState:
        """Step for Same-Step mode."""
        prev_done = state.done

        # Envs that finished on the previous step start counting from zero.
        steps = state.info.steps * (1 - prev_done).astype(jnp.int32)
        if self.record_episode_return:
            episode_return = state.info.episode_return * (1 - prev_done)

        obs, reward, termination, truncation, info = _to_jax(self._step(action))
        steps = steps + 1
        reward = reward.astype(jnp.float32)
        done = jnp.logical_or(termination, truncation).astype(jnp.float32)

        info.steps = steps
        info.termination = termination.astype(jnp.float32)
        info.truncation = truncation.astype(jnp.float32)
        if self.record_episode_return:
            if self.discount == 1.0:
                episode_return = episode_return + reward
            else:
                # `steps` was already incremented, hence the `- 1` exponent.
                episode_return = (
                    episode_return + jnp.power(self.discount, steps - 1) * reward
                )
            info.episode_return = episode_return

        return state.replace(obs=obs, reward=reward, done=done, info=info)

    @property
    def action_space(self) -> Space:
        """EvoRL space converted from the env's single-env action space."""
        return gymnasium_space_to_evorl_space(self.env.single_action_space)

    @property
    def obs_space(self) -> Space:
        """EvoRL space converted from the env's single-env observation space."""
        return gymnasium_space_to_evorl_space(self.env.single_observation_space)
256
257
[docs]
class OneEpisodeWrapper(Wrapper):
    """Vectorized one-episode wrapper for evaluation.

    Once ``state.done`` is nonzero for an env, further calls to ``step`` keep
    returning that env's previous state entries instead of the new ones.
    """

    def __init__(self, env: Env):
        super().__init__(env)

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env, keeping the old state where ``state.done`` is set.

        The wrapped env is always stepped; the selection between the old and
        the new value happens per leaf afterwards.
        """
        # Note: could add extra CPU overhead
        next_state = self.env.step(state, action)

        def _keep_if_done(old_leaf, new_leaf):
            mask = state.done
            if mask.ndim > 0:
                # Append singleton axes so the mask broadcasts over the leaf.
                trailing_ones = [1] * (len(old_leaf.shape) - 1)
                mask = jnp.reshape(mask, [old_leaf.shape[0]] + trailing_ones)
            return jnp.where(mask, old_leaf, new_leaf)

        return jtu.tree_map(_keep_if_done, state, next_state)
278
279
def _inf_to_num(x, num=1e10):
    """Replace ``+inf`` / ``-inf`` in ``x`` with ``+num`` / ``-num``.

    Delegates to ``jnp.nan_to_num``; its ``nan`` argument is left at the
    default, so NaN entries are replaced as well.
    """
    upper_bound = num
    lower_bound = -num
    return jnp.nan_to_num(x, neginf=lower_bound, posinf=upper_bound)
282
283
[docs]
def gymnasium_space_to_evorl_space(space: gymnasium.Space) -> Space:
    """Convert a Gymnasium space into the matching EvoRL space.

    Args:
        space: A ``gymnasium.spaces.Box`` or ``gymnasium.spaces.Discrete``.

    Returns:
        A ``Box`` whose bounds have infinities replaced by ``_inf_to_num``,
        or a ``Discrete`` with the same ``n``.

    Raises:
        NotImplementedError: For any other space type.
    """
    if isinstance(space, gymnasium.spaces.Box):
        lower = _inf_to_num(jnp.asarray(space.low))
        upper = _inf_to_num(jnp.asarray(space.high))
        return Box(low=lower, high=upper)

    if isinstance(space, gymnasium.spaces.Discrete):
        return Discrete(n=space.n)

    raise NotImplementedError(f"Unsupported space type: {type(space)}")
293
294
[docs]
295def create_gymnasium_env(
296 env_name,
297 episode_length: int = 1000,
298 parallel: int = 1,
299 autoreset_mode: AutoresetMode = AutoresetMode.ENVPOOL,
300 discount: float | None = 1.0,
301 record_ori_obs: bool = False,
302 **kwargs,
303) -> GymnasiumAdapter:
304 """Create a gym env based on Gymnasium.
305
306 Unlike other jax-based env, most wrappers are handled inside the gymnasium.
307 """
308 match autoreset_mode:
309 case AutoresetMode.FAST:
310 warnings.warn(
311 f"{autoreset_mode} is not supported for Gymnasium Envs. Fallback to AutoresetMode.NORMAL.",
312 )
313 gymnasium_autoreset_mode = gymnasium.vector.AutoresetMode.SAME_STEP
314 case AutoresetMode.NORMAL:
315 gymnasium_autoreset_mode = gymnasium.vector.AutoresetMode.SAME_STEP
316 case AutoresetMode.ENVPOOL:
317 gymnasium_autoreset_mode = gymnasium.vector.AutoresetMode.NEXT_STEP
318 if record_ori_obs:
319 warnings.warn(
320 f"{autoreset_mode} does not need record_ori_obs. Fallback to False.",
321 )
322 case AutoresetMode.DISABLED:
323 gymnasium_autoreset_mode = gymnasium.vector.AutoresetMode.NEXT_STEP
324 discount = None
325
326 mp.get_start_method("spawn")
327 vecenv_kwargs = dict(
328 autoreset_mode=gymnasium_autoreset_mode,
329 context="spawn", # jax's os.fork() warning remains
330 )
331 if "vecenv_kwargs" in kwargs:
332 vecenv_kwargs.update(kwargs.pop("vecenv_kwargs"))
333
334 env = GymnasiumAdapter(
335 env_name=env_name,
336 max_episode_steps=episode_length,
337 num_envs=parallel,
338 record_ori_obs=record_ori_obs,
339 discount=discount,
340 vecenv_kwargs=vecenv_kwargs,
341 **kwargs,
342 )
343
344 if autoreset_mode == AutoresetMode.DISABLED:
345 env = OneEpisodeWrapper(env)
346
347 return env
348
349
# Note: for the Humanoid and HumanoidStandup envs, the action space is [-0.4, 0.4]; we don't explicitly handle it. You need to manually squash the action space to [-1, 1] by using `ActionSquashWrapper`.