Source code for evorl.envs.wrappers.action_wrapper

 1import jax
 2import jax.numpy as jnp
 3import jax.tree_util as jtu
 4
 5from evorl.types import Action
 6
 7from ..env import Env, EnvState
 8from ..space import Box, Space
 9from .wrapper import Wrapper
10
11
class ActionSquashWrapper(Wrapper):
    """Expose a [-1, 1] action space on top of an env with a Box action space.

    Incoming actions are mapped with the affine transform
    `scale * action + bias`, which sends -1 to `low` and 1 to `high` of the
    wrapped env's action space. Actions are not clipped by this wrapper.
    """

    def __init__(self, env: Env):
        """Precompute the affine coefficients from the wrapped action space.

        Args:
            env: The environment to wrap. Its `action_space` must be a `Box`.
        """
        super().__init__(env)

        # TODO: support pytree action space
        inner_space = self.env.action_space
        assert isinstance(inner_space, Box), "Only support Box action_space"

        low, high = inner_space.low, inner_space.high
        # Half-width and midpoint of the [low, high] interval.
        self.scale = (high - low) * 0.5
        self.bias = (high + low) * 0.5

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Rescale `action` to the wrapped env's range and step the wrapped env.

        Args:
            state: Current environment state.
            action: Action expressed in the [-1, 1] range.

        Returns:
            The state returned by the wrapped env's `step`.
        """
        env_action = self.scale * action + self.bias
        next_state = self.env.step(state, env_action)
        return next_state

    @property
    def action_space(self) -> Space:
        """Box with bounds -1 and 1, shaped like the wrapped action space."""
        unit = jnp.ones_like(self.scale)
        return Box(low=-unit, high=unit)
32 33
class ActionRepeatWrapper(Wrapper):
    """Repeat action for a number of steps.

    :::{note}
    This wrapper only accumulates `state.reward` and `state.info.ori_reward`. It is safe to use `ActionRepeatWrapper(RewardScaleWrapper(EpisodeWrapper(env)))`. However, if you want to accumulate other metrics, inherit this class and add your own logic.
    :::
    :::{caution}
    When using rollout functions like `rollout`, `eval_rollout_episode` with `rollout_length` argument, users should use `math.ceil(env.max_episode_steps/action_repeat)` to match the real rollout_length.
    :::
    """

    def __init__(self, env: Env, action_repeat: int):
        """Store the wrapped env and the repeat count.

        Args:
            env: The environment to wrap.
            action_repeat: Number of times the wrapped env is stepped with the
                same action for each call to `step`.
        """
        super().__init__(env)

        self.action_repeat = action_repeat

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env `action_repeat` times with the same action.

        Once a state is done, the remaining repeats keep that state unchanged
        and contribute zero reward.

        Args:
            state: Current environment state.
            action: Action applied at every repeated step.

        Returns:
            The state after the repeated steps, with `reward` (and
            `info.ori_reward` when that key is present in `info`) replaced by
            the sum over the repeated steps.
        """

        def f(state, _):
            nstate = self.env.step(state, action)

            def where_done(x, y):
                done = state.done  # prev_done
                if done.ndim > 0:
                    # Add trailing singleton axes so that `done` broadcasts
                    # against leaves with extra trailing dimensions.
                    done = jnp.reshape(done, [x.shape[0]] + [1] * (len(x.shape) - 1))
                return jnp.where(done, x, y)

            # when prev_done=True, keep the previous state and set reward=0
            nstate = jtu.tree_map(where_done, state, nstate)

            reward = nstate.reward
            reward = jtu.tree_map(where_done, jnp.zeros_like(reward), reward)

            if "ori_reward" in nstate.info:
                ori_reward = nstate.info.ori_reward
                ori_reward = jtu.tree_map(
                    where_done, jnp.zeros_like(ori_reward), ori_reward
                )
            else:
                # None is an empty pytree, so scan passes it through untouched.
                ori_reward = None

            return nstate, (reward, ori_reward)

        state, (rewards, ori_rewards) = jax.lax.scan(
            f, state, (), length=self.action_repeat
        )

        def sum_over_repeats(x):
            # scan stacks the per-step values along a new leading axis; reduce
            # only that axis so that any remaining (e.g. batch) axes of the
            # reward are preserved instead of being collapsed into a scalar.
            return jnp.sum(x, axis=0)

        state = state.replace(
            reward=jtu.tree_map(sum_over_repeats, rewards),
        )

        if ori_rewards is not None:
            state = state.replace(
                info=state.info.replace(
                    ori_reward=jtu.tree_map(sum_over_repeats, ori_rewards),
                ),
            )

        return state