Source code for evorl.envs.wrappers.reward_wrapper
1import chex
2import jax.numpy as jnp
3import jax.tree_util as jtu
4from evorl.types import Action
5
6from ..env import Env, EnvState
7from .wrapper import Wrapper
8
9
[docs]
class RewardScaleWrapper(Wrapper):
    """Multiply every reward emitted by the wrapped env by a constant factor.

    The unscaled reward is preserved in ``info.ori_reward`` on both reset and
    step, so downstream code can still inspect the raw signal.

    Wrapper ordering matters for episode statistics:
    - Use EpisodeWrapper(RewardScaleWrapper(env)) to get the scaled `info.episode_return`.
    - Use RewardScaleWrapper(EpisodeWrapper(env)) to get the original `info.episode_return`.
    """

    def __init__(self, env: Env, reward_scale: float):
        super().__init__(env)
        # Multiplicative factor applied to every leaf of the reward pytree.
        self.reward_scale = reward_scale

    def _rescale(self, state: EnvState) -> EnvState:
        """Return ``state`` with a scaled reward and the raw one kept in info."""
        raw_reward = state.reward
        scale = self.reward_scale
        # The reward may be a pytree (e.g. per-agent), so scale every leaf.
        scaled_reward = jtu.tree_map(lambda leaf: leaf * scale, raw_reward)
        return state.replace(
            info=state.info.replace(ori_reward=raw_reward),
            reward=scaled_reward,
        )

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and scale its initial reward."""
        return self._rescale(self.env.reset(key))

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env and scale the resulting reward."""
        return self._rescale(self.env.step(state, action))
37
38
[docs]
class SparseRewardWrapper(Wrapper):
    """Convert a dense reward signal into a sparse one.

    Rewards from the wrapped env are accumulated and released all at once,
    either every ``sparse_length`` steps or on the step where ``done`` is
    true, whichever comes first. Every other step yields a zero reward.
    For dense rewards ``r_1, r_2, ...`` and ``sparse_length = k`` (with no
    episode end in between) the emitted rewards are::

        0, ..., 0, sum(r_1..r_k), 0, ..., 0, sum(r_{k+1}..r_{2k}), ...

    The running step counter and reward sum are stored in
    ``state._internal.cum_count`` and ``state._internal.cum_reward``.
    """

    def __init__(self, env: Env, sparse_length: int):
        super().__init__(env)
        # Number of steps to accumulate before releasing the summed reward.
        # NOTE(review): values < 1 make the release condition always true,
        # i.e. rewards stay dense; no validation is performed here.
        self.sparse_length = sparse_length

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and initialise the accumulators to zero.

        The reward produced by the inner reset is passed through unchanged.
        """
        state = self.env.reset(key)

        # Build a fresh ``_internal`` with ``replace`` instead of assigning
        # attributes in place: the container returned by the inner env may be
        # shared with other states, and ``step`` below already updates it
        # functionally.
        internal = state._internal.replace(
            cum_count=jnp.zeros((), dtype=jnp.int32),
            cum_reward=jtu.tree_map(jnp.zeros_like, state.reward),
        )
        return state.replace(_internal=internal)

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env and emit the accumulated reward when due.

        The step's reward is added to the running sum and the counter is
        incremented. If the counter reaches ``sparse_length`` or the episode
        is done, the sum is emitted as this step's reward and both
        accumulators are reset to zero; otherwise a zero reward is emitted.
        """
        state = self.env.step(state, action)

        cum_count = state._internal.cum_count + 1
        cum_reward = jtu.tree_map(
            lambda acc, r: acc + r, state._internal.cum_reward, state.reward
        )

        # Release on the window boundary or at episode end. ``jnp.bool_`` is
        # used because it exists in all JAX versions (``jnp.bool`` is a newer
        # alias).
        release = jnp.logical_or(
            cum_count >= self.sparse_length, state.done.astype(jnp.bool_)
        )

        reward = jtu.tree_map(
            lambda r: jnp.where(release, r, jnp.zeros_like(r)), cum_reward
        )

        # Start a new accumulation window after each release.
        cum_count = jnp.where(release, jnp.zeros_like(cum_count), cum_count)
        cum_reward = jtu.tree_map(
            lambda r: jnp.where(release, jnp.zeros_like(r), r), cum_reward
        )

        return state.replace(
            reward=reward,
            _internal=state._internal.replace(
                cum_count=cum_count, cum_reward=cum_reward
            ),
        )