Source code for evorl.envs.brax

  1import chex
  2import jax
  3import jax.numpy as jnp
  4import jax.tree_util as jtu
  5from brax.envs import (
  6    Env as BraxEnv,
  7    get_environment,
  8)
  9
 10from evorl.types import Action, PyTreeDict
 11
 12from .env import Env, EnvAdapter, EnvState
 13from .space import Box, Space, SpaceContainer
 14from .utils import sort_dict
 15from .wrappers.training_wrapper import (
 16    AutoresetMode,
 17    EpisodeWrapper,
 18    FastVmapAutoResetWrapper,
 19    OneEpisodeWrapper,
 20    VmapAutoResetWrapper,
 21    VmapEnvPoolAutoResetWrapper,
 22    VmapWrapper,
 23)
 24
 25
class BraxAdapter(EnvAdapter):
    """Expose a Brax environment through the ``EnvAdapter`` interface."""

    def __init__(self, env: BraxEnv):
        """Store the Brax environment by delegating to ``EnvAdapter``.

        Args:
            env: Brax environment instance to adapt.
        """
        super().__init__(env)

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped Brax environment.

        Args:
            key: PRNG key. It is split once and only the second sub-key is
                handed to the Brax ``reset``.

        Returns:
            An ``EnvState`` carrying the raw Brax state together with its
            obs, reward and done fields, plus an ``info`` tree in which the
            Brax metrics are stored under ``metrics``.
        """
        # Only the second sub-key is consumed; the first one is discarded.
        _, reset_key = jax.random.split(key)
        raw_state = self.env.reset(reset_key)

        # Both dicts go through ``sort_dict`` before being wrapped.
        info = PyTreeDict(sort_dict(raw_state.info))
        info.metrics = PyTreeDict(sort_dict(raw_state.metrics))

        return EnvState(
            env_state=raw_state,
            obs=raw_state.obs,
            reward=raw_state.reward,
            done=raw_state.done,
            info=info,
        )

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Advance the wrapped Brax environment by one step.

        Args:
            state: Current state; its ``env_state`` is passed to Brax.
            action: Action passed to the Brax ``step``.

        Returns:
            ``state`` with the env state, obs, reward, done and info
            replaced by the values produced by this step.
        """
        raw_state = self.env.step(state.env_state, action)

        # Update the existing info/metrics containers in place of building
        # new ones, so entries not reported by Brax are carried over.
        previous_info = state.info
        merged_metrics = previous_info.metrics.replace(**raw_state.metrics)
        merged_info = previous_info.replace(**raw_state.info, metrics=merged_metrics)

        return state.replace(
            env_state=raw_state,
            obs=raw_state.obs,
            reward=raw_state.reward,
            done=raw_state.done,
            info=merged_info,
        )

    @property
    def action_space(self) -> Space:
        """Box built from the actuators' ``ctrl_range``.

        Column 0 of ``ctrl_range`` is used as the lower bound and column 1
        as the upper bound, both as float32.
        """
        ctrl_range = jnp.asarray(self.env.sys.actuator.ctrl_range, dtype=jnp.float32)
        lower, upper = ctrl_range[:, 0], ctrl_range[:, 1]
        return Box(low=lower, high=upper)

    @property
    def obs_space(self) -> Space:
        """Observation space derived from ``env.observation_size``.

        An ``int`` size gives a single ``Box``. Any other value is treated
        as a pytree of sizes/shapes and mapped to a ``SpaceContainer`` of
        boxes. Every box is bounded by -1e10 and 1e10 (float32).
        """
        size_spec = self.env.observation_size

        def _is_shape(node):
            # A tuple made only of ints is one shape, not a subtree.
            return isinstance(node, tuple) and all(isinstance(dim, int) for dim in node)

        def _box_for(shape):
            shape = shape if isinstance(shape, tuple) else (shape,)
            bound = jnp.full(shape, 1e10, dtype=jnp.float32)
            return Box(low=-bound, high=bound)

        if isinstance(size_spec, int):
            return _box_for(size_spec)

        return SpaceContainer(
            spaces=jtu.tree_map(_box_for, size_spec, is_leaf=_is_shape),
        )
88 89
def create_brax_env(env_name: str, **kwargs) -> BraxAdapter:
    """Build a Brax environment and wrap it in a ``BraxAdapter``.

    Args:
        env_name: Environment name, passed to ``get_environment``.
        kwargs: Extra keyword arguments forwarded to ``get_environment``.

    Returns:
        The Brax environment wrapped in a ``BraxAdapter``.
    """
    return BraxAdapter(get_environment(env_name, **kwargs))
104 105
def create_wrapped_brax_env(
    env_name: str,
    episode_length: int = 1000,
    parallel: int = 1,
    autoreset_mode: AutoresetMode = AutoresetMode.NORMAL,
    discount: float | None = 1.0,
    record_ori_obs: bool = False,
    **kwargs,
) -> Env:
    """Build a Brax environment with the training wrappers applied.

    Args:
        env_name: Environment name.
        episode_length: Max episode length given to the episode wrapper.
        parallel: Number of parallel environments (``num_envs``).
        autoreset_mode: Selects which wrapper stack is applied.
        discount: Discount factor, forwarded to ``EpisodeWrapper``; not
            used when ``autoreset_mode`` is ``AutoresetMode.DISABLED``.
        record_ori_obs: Whether to record the original observation. Forced
            to ``False`` in ``AutoresetMode.ENVPOOL``.
        kwargs: Other arguments passed to ``create_brax_env``.

    Returns:
        Wrapped Brax env.
    """
    base_env = create_brax_env(env_name, **kwargs)

    # envpool mode will always record last obs
    keep_ori_obs = (
        False if autoreset_mode == AutoresetMode.ENVPOOL else record_ori_obs
    )

    # No auto-reset: single-episode wrapper followed by a plain vmap.
    if autoreset_mode == AutoresetMode.DISABLED:
        single_episode_env = OneEpisodeWrapper(
            base_env, episode_length, record_ori_obs=keep_ori_obs
        )
        return VmapWrapper(single_episode_env, num_envs=parallel, vmap_step=True)

    wrapped = EpisodeWrapper(
        base_env,
        episode_length,
        record_ori_obs=keep_ori_obs,
        discount=discount,
    )

    # A mode matching none of the branches keeps only the EpisodeWrapper.
    if autoreset_mode == AutoresetMode.FAST:
        wrapped = FastVmapAutoResetWrapper(wrapped, num_envs=parallel)
    elif autoreset_mode == AutoresetMode.NORMAL:
        wrapped = VmapAutoResetWrapper(wrapped, num_envs=parallel)
    elif autoreset_mode == AutoresetMode.ENVPOOL:
        wrapped = VmapEnvPoolAutoResetWrapper(wrapped, num_envs=parallel)

    return wrapped