Source code for evorl.envs.wrappers.reward_wrapper

 1import chex
 2import jax.numpy as jnp
 3import jax.tree_util as jtu
 4from evorl.types import Action
 5
 6from ..env import Env, EnvState
 7from .wrapper import Wrapper
 8
 9
class RewardScaleWrapper(Wrapper):
    """Multiply the reward of the wrapped env by a constant factor.

    The unscaled reward is stored in ``info.ori_reward`` on every returned
    state.

    Usage:
    - Use EpisodeWrapper(RewardScaleWrapper(env)) to get the scaled `info.episode_return`.
    - Use RewardScaleWrapper(EpisodeWrapper(env)) to get the original `info.episode_return`.
    """

    def __init__(self, env: Env, reward_scale: float):
        """Wrap ``env`` and remember the multiplier applied to its rewards.

        Args:
            env: The environment to wrap.
            reward_scale: Factor every reward leaf is multiplied by.
        """
        super().__init__(env)
        self.reward_scale = reward_scale

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and scale the reward of the resulting state.

        Args:
            key: PRNG key forwarded to the wrapped env's ``reset``.

        Returns:
            The reset state with its reward scaled and the unscaled reward
            recorded under ``info.ori_reward``.
        """
        inner_state = self.env.reset(key)

        # The reward may be a pytree, so scale it leaf by leaf.
        scaled_reward = jtu.tree_map(
            lambda leaf: leaf * self.reward_scale, inner_state.reward
        )
        # Keep the untouched reward available to downstream consumers.
        info_with_original = inner_state.info.replace(
            ori_reward=inner_state.reward
        )

        return inner_state.replace(reward=scaled_reward, info=info_with_original)

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env and scale the reward of the resulting state.

        Args:
            state: Current env state.
            action: Action forwarded to the wrapped env's ``step``.

        Returns:
            The next state with its reward scaled and the unscaled reward
            recorded under ``info.ori_reward``.
        """
        next_state = self.env.step(state, action)

        # The reward may be a pytree, so scale it leaf by leaf.
        scaled_reward = jtu.tree_map(
            lambda leaf: leaf * self.reward_scale, next_state.reward
        )
        # Keep the untouched reward available to downstream consumers.
        info_with_original = next_state.info.replace(ori_reward=next_state.reward)

        return next_state.replace(reward=scaled_reward, info=info_with_original)
37 38
class SparseRewardWrapper(Wrapper):
    """Convert dense reward to sparse reward.

    Rewards from the wrapped env are accumulated and only emitted once the
    step counter reaches ``sparse_length`` or the state reports ``done``;
    every other step yields a zero reward.

    The dense rewards become: 0, 0, ..., sum(rewards), 0, 0, ..., sum(rewards)
    """

    def __init__(self, env: Env, sparse_length: int):
        """Wrap ``env`` and remember the accumulation window.

        Args:
            env: The environment to wrap.
            sparse_length: Number of steps after which the accumulated
                reward is emitted.
        """
        super().__init__(env)
        self.sparse_length = sparse_length

    def reset(self, key: chex.PRNGKey) -> EnvState:
        """Reset the wrapped env and initialise the accumulators.

        Args:
            key: PRNG key forwarded to the wrapped env's ``reset``.

        Returns:
            The reset state, with a zero step counter and a zero reward
            accumulator stored on ``state._internal``. The reward of the
            reset state itself is left as returned by the wrapped env.
        """
        state = self.env.reset(key)

        # The accumulators live on the state's `_internal` container and are
        # written in place here.
        state._internal.cum_count = jnp.zeros((), dtype=jnp.int32)
        # Same pytree structure as the reward, all leaves zero.
        state._internal.cum_reward = jtu.tree_map(
            lambda leaf: jnp.zeros_like(leaf), state.reward
        )

        return state

    def step(self, state: EnvState, action: Action) -> EnvState:
        """Step the wrapped env and emit the accumulated reward when due.

        Args:
            state: Current env state carrying ``_internal.cum_count`` and
                ``_internal.cum_reward``.
            action: Action forwarded to the wrapped env's ``step``.

        Returns:
            The next state whose reward is the accumulated sum when the
            window is full or the state is done, and zeros otherwise. The
            accumulators are cleared whenever the sum is emitted.
        """
        state = self.env.step(state, action)

        # Fold the current step into the running totals.
        steps_so_far = state._internal.cum_count + 1
        reward_so_far = jtu.tree_map(
            lambda acc, new: acc + new, state._internal.cum_reward, state.reward
        )

        # Emit when the window is full or the state reports done.
        window_full = steps_so_far >= self.sparse_length
        episode_done = state.done.astype(jnp.bool)
        emit = jnp.logical_or(window_full, episode_done)

        # Reward seen by the caller: the running sum on emit, zeros otherwise.
        sparse_reward = jtu.tree_map(
            lambda leaf: jnp.where(emit, leaf, jnp.zeros_like(leaf)), reward_so_far
        )

        # Start a fresh window after emitting; otherwise carry the totals on.
        next_count = jnp.where(emit, jnp.zeros_like(steps_so_far), steps_so_far)
        next_reward_sum = jtu.tree_map(
            lambda leaf: jnp.where(emit, jnp.zeros_like(leaf), leaf), reward_so_far
        )

        updated_internal = state._internal.replace(
            cum_count=next_count, cum_reward=next_reward_sum
        )
        return state.replace(reward=sparse_reward, _internal=updated_internal)