Source code for evorl.envs.wrappers.action_wrapper
1import jax
2import jax.numpy as jnp
3import jax.tree_util as jtu
4
5from evorl.types import Action
6
7from ..env import Env, EnvState
8from ..space import Box, Space
9from .wrapper import Wrapper
10
11
[docs]
class ActionSquashWrapper(Wrapper):
    """Map actions from the normalized range [-1, 1] onto the env's [low, high].

    The wrapped env must expose a ``Box`` action space. This wrapper advertises
    a unit box of the same shape instead, and affinely rescales every incoming
    action before forwarding it to the wrapped env. Actions are not clipped:
    values outside [-1, 1] land outside [low, high].
    """

    def __init__(self, env: Env):
        super().__init__(env)

        # TODO: support pytree action space
        inner_space = self.env.action_space
        assert isinstance(inner_space, Box), "Only support Box action_space"

        # Half-width and midpoint of [low, high], so that -1 -> low, +1 -> high.
        high, low = inner_space.high, inner_space.low
        self.scale = (high - low) * 0.5
        self.bias = (high + low) * 0.5

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Rescale ``action`` to the wrapped env's bounds, then step it."""
        rescaled = self.scale * action + self.bias
        return self.env.step(state, rescaled)

    @property
    def action_space(self) -> Space:
        """Unit ``Box`` spanning [-1, 1], shaped like ``scale``."""
        unit = jnp.ones_like(self.scale)
        return Box(low=-unit, high=unit)
32
33
[docs]
class ActionRepeatWrapper(Wrapper):
    """Repeat action for a number of steps.

    :::{note}
    This wrapper only accumulates `state.reward` and `state.info.ori_reward`. It is safe to use `ActionRepeatWrapper(RewardScaleWrapper(EpisodeWrapper(env)))`. However, if you want to accumulate other metrics, inherit this class and add your own logic.
    :::
    :::{caution}
    When using rollout functions like `rollout`, `eval_rollout_episode` with `rollout_length` argument, users should use `math.ceil(env.max_episode_steps/action_repeat)` to match the real rollout_length.
    :::
    """

    def __init__(self, env: Env, action_repeat: int):
        super().__init__(env)

        # Number of inner `env.step` calls performed per outer `step`.
        self.action_repeat = action_repeat

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Apply ``action`` for ``action_repeat`` inner steps and sum the rewards.

        Once an inner step reports ``done``, every later inner step is masked.
        The wrapped env is still called, but its output is discarded: the state
        is frozen at the terminal step and the masked steps contribute zero
        reward.

        Args:
            state: Current environment state.
            action: Action applied at every inner step.

        Returns:
            The state after the last unmasked inner step. Its ``reward`` (and
            ``info.ori_reward``, if present) is replaced by the sum over the
            inner steps, reduced only along the repeat axis.
        """

        def _select(done, x, y):
            # Pick `x` where `done` else `y`. `done` may carry leading batch
            # dims; append singleton axes so it broadcasts over the leaf's
            # trailing dims (a no-op reshape-to-broadcast for scalar `done`).
            done = jnp.reshape(
                done, jnp.shape(done) + (1,) * (jnp.ndim(x) - jnp.ndim(done))
            )
            return jnp.where(done, x, y)

        def _zero_where(done, tree):
            # Zero every leaf of `tree` where `done` is set; supports pytrees.
            return jtu.tree_map(lambda r: _select(done, jnp.zeros_like(r), r), tree)

        def _inner_step(prev_state, _):
            prev_done = prev_state.done
            next_state = self.env.step(prev_state, action)

            # When prev_done=True, keep the previous state ...
            next_state = jtu.tree_map(
                lambda x, y: _select(prev_done, x, y), prev_state, next_state
            )

            # ... and credit zero reward for that masked step.
            reward = _zero_where(prev_done, next_state.reward)

            if "ori_reward" in next_state.info:
                ori_reward = _zero_where(prev_done, next_state.info.ori_reward)
            else:
                ori_reward = None

            return next_state, (reward, ori_reward)

        state, (rewards, ori_rewards) = jax.lax.scan(
            _inner_step, state, (), length=self.action_repeat
        )

        def _sum_over_repeats(tree):
            # `scan` stacks per-step outputs along a new leading axis of length
            # `action_repeat`. Reduce only that axis so batch and vector-reward
            # dims are preserved; a plain `jnp.sum` would collapse them all.
            return jtu.tree_map(lambda r: jnp.sum(r, axis=0), tree)

        state = state.replace(reward=_sum_over_repeats(rewards))

        if ori_rewards is not None:
            state = state.replace(
                info=state.info.replace(ori_reward=_sum_over_repeats(ori_rewards)),
            )

        return state
91 return state