Source code for evorl.envs.gymnasium

  1from functools import partial
  2import warnings
  3import numpy as np
  4import math
  5import gymnasium
  6import multiprocessing as mp
  7
  8import chex
  9import jax
 10import jax.numpy as jnp
 11import jax.tree_util as jtu
 12
 13from evorl.types import Action, PyTreeDict
 14
 15from .env import Env, EnvAdapter, EnvState
 16from .space import Box, Discrete, Space
 17from .wrappers import Wrapper, AutoresetMode
 18
 19
 20def _to_jax(pytree):
 21    return jtu.tree_map(lambda x: jnp.asarray(x), pytree)
 22
 23
 24def _to_jax_spec(pytree):
 25    pytree = _to_jax(pytree)
 26    return jtu.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), pytree)
 27
 28
 29def _reshape_batch_dims(pytree, batch_shape):
 30    # [B1*...*Bn*#envs, *] -> [B1, ..., Bn, #envs, *]
 31    return jtu.tree_map(lambda x: jnp.reshape(x, batch_shape + x.shape[1:]), pytree)
 32
 33
class GymnasiumAdapter(EnvAdapter):
    """Adapter for Gymnasium to support Gymnasium environments.

    This env already is a vectorized environment and has experimental supports. It is not recommended to directly replace other jax-based envs with this env in EvoRL's existing workflows. Users should carefully check the compatibility and modify the corresponding code to avoid the side-effects and other undefined behaviors.

    :::{caution}
    This env breaks the rule of pure functions. Its env state is maintained inside the gymnasium. Therefore, users should use it with caution. Unlike other jax-based envs, this env has following limitations:

    - No support for recovering from previous env state.
        - In other word, you can't rewind after calling `env.step`.
        - For example, you can't resume the training from a checkpoint exactly as before; Similarly, `evorl.rollout.eval_rollout_episode` will also result in undefined behavior.
    - We use gymnasium's `AsyncVectorEnv`, which uses python's `multiprocessing` package for parallelism. This may cause performance issues, especially when the number of parallel environments is large.
        - We recommend that the total number of parallel environments does not exceed the number of CPU logic cores.
    :::
    """

    # TODO: multi-device support

    def __init__(
        self,
        env_name: str,
        max_episode_steps: int,
        num_envs: int,
        record_ori_obs: bool = False,
        discount: float | None = None,
        vecenv_kwargs: dict | None = None,
        **env_kwargs,
    ):
        """Create the underlying vectorized gymnasium env and its callbacks.

        Args:
            env_name: Environment id forwarded to ``gymnasium.make_vec``.
            max_episode_steps: Forwarded to ``gymnasium.make_vec``.
            num_envs: Number of parallel sub-environments per (unbatched)
                call of ``reset``.
            record_ori_obs: Whether ``info.ori_obs`` is filled in by
                ``reset``/``step``.
            discount: Discount used for ``info.episode_return``. ``None``
                disables episode-return bookkeeping.
            vecenv_kwargs: Forwarded to ``gymnasium.make_vec`` as
                ``vector_kwargs``.
            **env_kwargs: Extra keyword arguments forwarded to
                ``gymnasium.make_vec``.
        """
        self.env_name = env_name
        self.max_episode_steps = max_episode_steps
        self.num_envs = num_envs
        self.record_ori_obs = record_ori_obs
        self.record_episode_return = discount is not None
        self.discount = discount
        self.vecenv_kwargs = vecenv_kwargs
        self.env_specs = env_kwargs

        def _env_fn(num_envs):
            # Factory kept around because `_reset` rebuilds the vector env
            # with a (possibly larger) number of sub-environments.
            env = gymnasium.make_vec(
                self.env_name,
                num_envs=num_envs,
                max_episode_steps=self.max_episode_steps,
                vectorization_mode=gymnasium.VectorizeMode.ASYNC,
                vector_kwargs=self.vecenv_kwargs,
                **self.env_specs,
            )

            return env

        self._env_fn = _env_fn
        self.env = _env_fn(num_envs)
        self.autoreset_mode = self.env.metadata["autoreset_mode"]

        self.setup_env_callback()

    def setup_env_callback(self):
        """Build ``self._reset`` / ``self._step`` as ``jax.pure_callback`` wrappers.

        The env is reset and stepped once with a sampled action to obtain
        the shapes/dtypes of the callback results. Note that this advances
        the underlying gymnasium env.
        """
        dummy_obs, _ = self.env.reset()
        # define your own dummy reset info here
        dummy_reset_info = PyTreeDict()
        if self.record_ori_obs:
            dummy_reset_info.ori_obs = dummy_obs
        reset_spec = _to_jax_spec((dummy_obs, dummy_reset_info))

        dummy_action = self.env.single_action_space.sample()
        dummy_actions = np.broadcast_to(
            dummy_action, (self.num_envs,) + dummy_action.shape
        )
        # define your own dummy step info here
        dummy_step_info = PyTreeDict()
        if self.record_ori_obs:
            dummy_step_info.ori_obs = dummy_obs

        # `[:-1]` drops gymnasium's own info dict from the step result.
        step_spec = _to_jax_spec(self.env.step(dummy_actions)[:-1] + (dummy_step_info,))

        def _reset(key):
            # Everything but the trailing key axis is treated as batch dims.
            batch_shape = tuple(key.shape[:-1])
            num_envs = math.prod(batch_shape) * self.num_envs

            # TODO: reuse the multiprocessing workers from prev self.env,
            # to avoid creating new processes.
            self.env = self._env_fn(num_envs)

            assert self.env.num_envs == num_envs

            # Drop the original info dict: it does not have a static shape,
            # so it is neither returned nor reshaped (its leaves are not
            # guaranteed to be reshapeable numeric arrays).
            raw_obs, _ = self.env.reset()
            obs = _reshape_batch_dims(raw_obs, batch_shape + (self.num_envs,))

            info = PyTreeDict()
            if self.record_ori_obs:
                info.ori_obs = jnp.zeros_like(obs)

            return obs, info

        def _step(actions):
            # Note: we are not sure if self.env is always updated by _reset in JIT mode.

            # Rank of a single env's action. This is 0 for scalar actions
            # (gymnasium's Discrete space has shape ``()``).
            action_ndim = len(self.env.single_action_space.shape)

            # [B1, ..., Bn, #envs]
            # Slice with an explicit stop index: `shape[:-action_ndim]` would
            # degenerate to `shape[:0]` (an empty batch shape) when
            # action_ndim == 0.
            batch_shape = tuple(actions.shape[: actions.ndim - action_ndim])

            # [B1, ..., Bn, #envs, *] -> [B1*...*Bn*#envs, *]
            actions = jax.lax.collapse(actions, 0, len(batch_shape))

            # [B1*...*Bn*#envs, *]
            obs, reward, termination, truncation, _info = self.env.step(
                np.asarray(actions)
            )

            # drop the original info dict as they do not have static shape.
            info = PyTreeDict()
            if self.record_ori_obs:
                ori_obs = obs.copy()
                final_obs_list = _info.get("final_obs", None)
                if final_obs_list is not None:
                    valid_indices = np.array(
                        [o is not None for o in final_obs_list]
                    ).nonzero()[0]
                    # np.stack raises on an empty list, so only patch when at
                    # least one sub-env actually reported a final observation.
                    if valid_indices.size > 0:
                        ori_obs[valid_indices] = np.stack(
                            [final_obs_list[i] for i in valid_indices]
                        )

                info.ori_obs = ori_obs

            return _reshape_batch_dims(
                (obs, reward, termination, truncation, info), batch_shape
            )

        # You are entering the dangerous zone!!!
        # _reset and _step are not pure functions. Use with caution.
        self._reset = partial(
            jax.pure_callback, _reset, reset_spec, vmap_method="expand_dims"
        )
        self._step = partial(
            jax.pure_callback, _step, step_spec, vmap_method="expand_dims"
        )

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the env and return the initial ``EnvState``.

        ``key`` is only passed to the reset callback, which reads its shape;
        the returned state carries ``env_state=None`` because the real state
        lives inside gymnasium.
        """
        obs, info = _to_jax(self._reset(key))

        info.steps = jnp.zeros((self.num_envs,), dtype=jnp.int32)
        info.termination = jnp.zeros((self.num_envs,))
        info.truncation = jnp.zeros((self.num_envs,))

        if self.autoreset_mode == gymnasium.vector.AutoresetMode.NEXT_STEP:
            info.autoreset = jnp.zeros((self.num_envs,))
        if self.record_episode_return:
            info.episode_return = jnp.zeros((self.num_envs,))

        return EnvState(
            env_state=None,
            obs=obs,
            reward=jnp.zeros((self.num_envs,)),
            done=jnp.zeros((self.num_envs,)),
            info=info,
        )

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Advance the env by one step, dispatching on the autoreset mode.

        Raises:
            NotImplementedError: If the gymnasium autoreset mode is neither
                ``NEXT_STEP`` nor ``SAME_STEP``.
        """
        if self.autoreset_mode == gymnasium.vector.AutoresetMode.NEXT_STEP:
            return self._envpool_autoreset_step(state, action)
        elif self.autoreset_mode == gymnasium.vector.AutoresetMode.SAME_STEP:
            return self._normal_autoreset_step(state, action)
        else:
            raise NotImplementedError(
                f"Unsupported autoreset mode: {self.autoreset_mode}"
            )

    def _envpool_autoreset_step(self, state: EnvState, action: Action) -> EnvState:
        """Step for Next-Step mode."""
        autoreset = state.done  # True = this step is the reset() step

        obs, reward, termination, truncation, info = _to_jax(self._step(action))

        reward = reward.astype(jnp.float32)
        done = jnp.logical_or(termination, truncation).astype(jnp.float32)

        # The step counter restarts from 0 on the step that performs the reset.
        info.steps = (state.info.steps + 1) * (1 - autoreset).astype(jnp.int32)
        info.termination = termination.astype(jnp.float32)
        info.truncation = truncation.astype(jnp.float32)
        info.autoreset = autoreset  # prev_done

        if self.record_episode_return:
            episode_return = state.info.episode_return
            if self.discount == 1.0:
                episode_return = episode_return + reward
            else:
                episode_return = (
                    episode_return
                    + jnp.power(self.discount, state.info.steps) * reward
                )
            # The accumulated return is cleared on the reset step.
            info.episode_return = episode_return * (1 - autoreset)

        return state.replace(obs=obs, reward=reward, done=done, info=info)

    def _normal_autoreset_step(self, state: EnvState, action: Action) -> EnvState:
        """Step for Same-Step mode."""
        prev_done = state.done

        # Counters of envs that finished on the previous step restart here.
        steps = state.info.steps * (1 - prev_done).astype(jnp.int32)
        if self.record_episode_return:
            episode_return = state.info.episode_return * (1 - prev_done)

        obs, reward, termination, truncation, info = _to_jax(self._step(action))
        steps = steps + 1
        reward = reward.astype(jnp.float32)
        done = jnp.logical_or(termination, truncation).astype(jnp.float32)

        info.steps = steps
        info.termination = termination.astype(jnp.float32)
        info.truncation = truncation.astype(jnp.float32)
        if self.record_episode_return:
            if self.discount == 1.0:
                episode_return = episode_return + reward
            else:
                episode_return = (
                    episode_return + jnp.power(self.discount, steps - 1) * reward
                )
            info.episode_return = episode_return

        return state.replace(obs=obs, reward=reward, done=done, info=info)

    @property
    def action_space(self) -> Space:
        """EvoRL space converted from the gymnasium single-env action space."""
        return gymnasium_space_to_evorl_space(self.env.single_action_space)

    @property
    def obs_space(self) -> Space:
        """EvoRL space converted from the gymnasium single-env observation space."""
        return gymnasium_space_to_evorl_space(self.env.single_observation_space)
256 257
class OneEpisodeWrapper(Wrapper):
    """Vectorized one-episode wrapper for evaluation."""

    def __init__(self, env: Env):
        super().__init__(env)

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env, but keep the old state wherever ``state.done`` is set.

        The wrapped env is always stepped; for every leaf of the state
        pytree, entries whose ``state.done`` flag is non-zero keep their
        previous value, the others take the freshly computed value.
        """
        # Note: could add extra CPU overhead

        next_state = self.env.step(state, action)
        done = state.done

        def _keep_if_done(old, new):
            mask = done
            if mask.ndim > 0:
                # Reshape to [N, 1, ..., 1] so the mask broadcasts over the
                # trailing dimensions of this leaf.
                mask = jnp.reshape(mask, [old.shape[0]] + [1] * (len(old.shape) - 1))
            return jnp.where(mask, old, new)

        return jtu.tree_map(_keep_if_done, state, next_state)
def _inf_to_num(x, num=1e10):
    """Replace infinite entries of ``x`` by the finite bound ``±num``.

    Args:
        x: Array-like input.
        num: Magnitude substituted for ``+inf`` (as ``num``) and ``-inf``
            (as ``-num``).

    Returns:
        The result of ``jnp.nan_to_num``; note that NaN entries are therefore
        also replaced (by that function's default, ``0.0``).
    """
    return jnp.nan_to_num(x, posinf=num, neginf=-num)
def gymnasium_space_to_evorl_space(space: gymnasium.Space) -> Space:
    """Convert a gymnasium space into the matching EvoRL space.

    Args:
        space: A ``gymnasium.spaces.Box`` or ``gymnasium.spaces.Discrete``.

    Returns:
        A ``Box`` whose bounds are the gymnasium bounds with infinities
        clipped by ``_inf_to_num``, or a ``Discrete`` with the same ``n``.

    Raises:
        NotImplementedError: For any other gymnasium space type.
    """
    if isinstance(space, gymnasium.spaces.Box):
        lower = _inf_to_num(jnp.asarray(space.low))
        upper = _inf_to_num(jnp.asarray(space.high))
        return Box(low=lower, high=upper)

    if isinstance(space, gymnasium.spaces.Discrete):
        return Discrete(n=space.n)

    raise NotImplementedError(f"Unsupported space type: {type(space)}")
293 294
[docs] 295def create_gymnasium_env( 296 env_name, 297 episode_length: int = 1000, 298 parallel: int = 1, 299 autoreset_mode: AutoresetMode = AutoresetMode.ENVPOOL, 300 discount: float | None = 1.0, 301 record_ori_obs: bool = False, 302 **kwargs, 303) -> GymnasiumAdapter: 304 """Create a gym env based on Gymnasium. 305 306 Unlike other jax-based env, most wrappers are handled inside the gymnasium. 307 """ 308 match autoreset_mode: 309 case AutoresetMode.FAST: 310 warnings.warn( 311 f"{autoreset_mode} is not supported for Gymnasium Envs. Fallback to AutoresetMode.NORMAL.", 312 ) 313 gymnasium_autoreset_mode = gymnasium.vector.AutoresetMode.SAME_STEP 314 case AutoresetMode.NORMAL: 315 gymnasium_autoreset_mode = gymnasium.vector.AutoresetMode.SAME_STEP 316 case AutoresetMode.ENVPOOL: 317 gymnasium_autoreset_mode = gymnasium.vector.AutoresetMode.NEXT_STEP 318 if record_ori_obs: 319 warnings.warn( 320 f"{autoreset_mode} does not need record_ori_obs. Fallback to False.", 321 ) 322 case AutoresetMode.DISABLED: 323 gymnasium_autoreset_mode = gymnasium.vector.AutoresetMode.NEXT_STEP 324 discount = None 325 326 mp.get_start_method("spawn") 327 vecenv_kwargs = dict( 328 autoreset_mode=gymnasium_autoreset_mode, 329 context="spawn", # jax's os.fork() warning remains 330 ) 331 if "vecenv_kwargs" in kwargs: 332 vecenv_kwargs.update(kwargs.pop("vecenv_kwargs")) 333 334 env = GymnasiumAdapter( 335 env_name=env_name, 336 max_episode_steps=episode_length, 337 num_envs=parallel, 338 record_ori_obs=record_ori_obs, 339 discount=discount, 340 vecenv_kwargs=vecenv_kwargs, 341 **kwargs, 342 ) 343 344 if autoreset_mode == AutoresetMode.DISABLED: 345 env = OneEpisodeWrapper(env) 346 347 return env
# Note: for the Humanoid and HumanoidStandup envs, the action space is [-0.4, 0.4]; we don't explicitly handle it. You need to manually squash the action space to [-1, 1] by using `ActionSquashWrapper`.