Source code for evorl.utils.rl_toolkits

  1from functools import partial
  2import chex
  3import jax
  4import jax.numpy as jnp
  5import jax.tree_util as jtu
  6
  7from evorl.sample_batch import SampleBatch
  8from evorl.types import MISSING_REWARD
  9
 10from .jax_utils import is_jitted, sliding_window
 11
 12
def compute_episode_length(
    dones: chex.Array,  # [T, B]
) -> chex.Array:
    """Count the number of steps of each episode in a batch of trajectories.

    The first step is always counted; every later step ``t`` is counted when
    ``dones[t - 1]`` is not set.

    NOTE(review): this yields the true episode length only if the done flag
    stays set on every step after the episode has ended -- confirm against
    the rollout code that produces ``dones``.

    Args:
        dones: Done flags collected from an episodic trajectory, laid out
            as [T, B].

    Returns:
        Episode lengths, one entry per batch element ([B]).
    """
    # A step at index t >= 1 counts iff the step before it was not done.
    counted_after_first = 1 - dones[:-1].astype(jnp.int32)

    # [B]; the trailing "+ 1" accounts for the very first step.
    return counted_after_first.sum(axis=0) + 1
27 28
def compute_discount_return(
    rewards: chex.Array,  # [T, B]
    dones: chex.Array,  # [T, B]
    discount: float = 1.0,
) -> chex.Array:
    """Accumulate the discounted return of each trajectory in the batch.

    The trajectory is walked backwards in time with the recursion
    ``G_t = r_t + discount * (1 - done_t) * G_{t+1}``, starting from
    ``G_T = 0``, and ``G_0`` is returned.

    Args:
        rewards: Rewards collected from an episodic trajectory, [T, B].
        dones: Done flags collected from an episodic trajectory, [T, B].
        discount: Discount factor.

    Returns:
        The discounted return at the first step, one entry per batch
        element ([B]).
    """

    def _backup(return_t_plus_1, step):
        # G_t := r_t + γ_t * G_{t+1}
        reward_t, discount_t = step
        return reward_t + return_t_plus_1 * discount_t, None

    # A done step gets a zero discount, so nothing that follows it in time
    # leaks back into the return.
    step_discounts = (1 - dones) * discount

    # [B]
    bootstrap_return = jnp.zeros_like(rewards[0])

    first_step_return, _ = jax.lax.scan(
        _backup,
        bootstrap_return,
        (rewards, step_discounts),
        reverse=True,
        unroll=16,
    )

    return first_step_return  # [B]
64 65
def compute_gae(
    rewards: jax.Array,  # [T, B]
    values: jax.Array,  # [T+1, B]
    dones: jax.Array,  # [T, B]
    terminations: jax.Array,  # [T, B]
    gae_lambda: float = 1.0,
    discount: float = 0.99,
) -> tuple[jax.Array, jax.Array]:
    """Compute Generalized Advantage Estimation (GAE) over a trajectory.

    The TD residuals bootstrap from the next value unless the step is a
    termination, while the backward accumulation of the advantage is cut
    at every done step:

        delta_t = r_t + discount * (1 - termination_t) * V_{t+1} - V_t
        A_t     = delta_t + discount * gae_lambda * (1 - done_t) * A_{t+1}

    Args:
        rewards: Rewards of shape [T, B] generated by following the
            behaviour policy.
        values: Value estimates of shape [T+1, B] wrt. the target policy;
            ``values[T]`` is the bootstrap value.
        dones: Done signal of shape [T, B].
        terminations: Termination signal of shape [T, B].
        gae_lambda: Mix between 1-step (gae_lambda=0) and n-step
            (gae_lambda=1).
        discount: TD discount.

    Returns:
        Tuple:
            - Lambda returns with shape [T, B], usable as regression
              targets for a value baseline.
            - Advantages with shape [T, B].
    """
    num_steps = rewards.shape[0]
    # values must carry exactly one more time step than rewards.
    chex.assert_shape(values, (num_steps + 1, *rewards.shape[1:]))

    values_t = values[:-1]
    values_t_plus_1 = values[1:]

    # One-step TD residuals; only true terminations drop the bootstrap value.
    deltas = rewards + discount * (1 - terminations) * values_t_plus_1 - values_t
    # Per-step decay of the accumulated advantage; any done step resets it.
    decays = discount * gae_lambda * (1 - dones)

    def _accumulate(advantage_t_plus_1, step):
        delta_t, decay_t = step
        advantage_t = delta_t + decay_t * advantage_t_plus_1
        return advantage_t, advantage_t

    _, advantages = jax.lax.scan(
        _accumulate,
        jnp.zeros_like(values[0]),  # advantage beyond the last step
        (deltas, decays),
        reverse=True,
        unroll=16,
    )

    return advantages + values_t, advantages
116 117
def compute_gae_with_horizon(
    rewards: jax.Array,  # [T, B]
    values: jax.Array,  # [T+1, B]
    dones: jax.Array,  # [T, B]
    terminations: jax.Array,  # [T, B]
    gae_horizon: int = 0,
    gae_lambda: float = 1.0,
    discount: float = 0.99,
) -> tuple[jax.Array, jax.Array]:
    """Compute GAE, optionally on consecutive chunks of ``gae_horizon`` steps.

    With ``gae_horizon > 0`` the trajectory is cut into ``T // gae_horizon``
    chunks, ``compute_gae`` is vmapped over the chunks and the results are
    stitched back to [T, B]. Otherwise ``compute_gae`` is applied once to
    the whole trajectory.

    Args:
        rewards: Rewards of shape [T, B].
        values: Value estimates of shape [T+1, B].
        dones: Done signal of shape [T, B].
        terminations: Termination signal of shape [T, B].
        gae_horizon: Chunk length; values <= 0 disable chunking. When
            positive it must divide T.
        gae_lambda: GAE lambda forwarded to ``compute_gae``.
        discount: TD discount forwarded to ``compute_gae``.

    Returns:
        Tuple of (value targets, advantages) as returned by ``compute_gae``,
        both with shape [T, B].
    """
    gae_fn = partial(compute_gae, gae_lambda=gae_lambda, discount=discount)

    if gae_horizon <= 0:
        # No chunking: a single pass over the full trajectory.
        return gae_fn(rewards, values, dones, terminations)

    assert rewards.shape[0] % gae_horizon == 0

    def _split_into_chunks(x):
        # [T, B] -> [T//h, h, B]
        return x.reshape(-1, gae_horizon, *x.shape[1:])

    def _merge_chunks(x):
        # [T//h, h, B] -> [T, B]
        return x.reshape(-1, *x.shape[2:])

    rewards, dones, terminations = jtu.tree_map(
        _split_into_chunks,
        (rewards, dones, terminations),
    )
    # [T+1, B] -> [T//h, h+1, B]: each chunk also needs the value that
    # follows its last step as the bootstrap value.
    values = sliding_window(values, gae_horizon + 1, gae_horizon)

    v_targets, advantages = jax.vmap(gae_fn)(rewards, values, dones, terminations)

    v_targets, advantages = jtu.tree_map(_merge_chunks, (v_targets, advantages))

    return v_targets, advantages
161 162
def shuffle_sample_batch(sample_batch: SampleBatch, key: chex.PRNGKey):
    """Randomly permute every leaf of the sample batch with the same key.

    Args:
        sample_batch: Pytree of arrays to shuffle.
        key: PRNG key passed unchanged to ``jax.random.permutation`` for
            every leaf.

    Returns:
        A pytree with the same structure whose leaves are permuted.
    """

    def _permute(leaf):
        # The key is deliberately reused across leaves rather than split.
        return jax.random.permutation(key, leaf)

    return jtu.tree_map(_permute, sample_batch)
166 167
def soft_target_update(target_params, source_params, tau: float):
    """Blend the source parameters into the target parameters.

    Every leaf of the result is ``tau * source + (1 - tau) * target``.

    Args:
        target_params: Target network parameters.
        source_params: Source network parameters, with the same tree
            structure as ``target_params``.
        tau: Interpolation factor; the weight given to the source.

    Returns:
        Updated target network parameters.
    """

    def _blend(target_leaf, source_leaf):
        return tau * source_leaf + (1 - tau) * target_leaf

    return jtu.tree_map(_blend, target_params, source_params)
184 185
def flatten_rollout_trajectory(trajectory: SampleBatch) -> SampleBatch:
    """Merge the two leading axes of every leaf: [T, B, ...] -> [T*B, ...]."""

    def _merge_leading_axes(leaf):
        # Collapse dimensions 0 and 1 into a single one.
        return jax.lax.collapse(leaf, 0, 2)

    return jtu.tree_map(_merge_leading_axes, trajectory)
189 190
def flatten_pop_rollout_episode(trajectory: SampleBatch):
    """Fold the population axis into the batch axis of every leaf.

    Each leaf goes from [#pop, T, B, ...] to [T, #pop*B, ...].
    """

    def _fold_population(leaf):
        # [#pop, T, B, ...] -> [T, #pop, B, ...]
        time_major = leaf.swapaxes(0, 1)
        # [T, #pop, B, ...] -> [T, #pop*B, ...]
        return jax.lax.collapse(time_major, 1, 3)

    return jtu.tree_map(_fold_population, trajectory)
194 195
def average_episode_discount_return(
    episode_discount_return: jax.Array,  # [T,B]
    dones: jax.Array,  # [T,B]
    dp_axis_name: str | None = None,
) -> jax.Array:
    """Estimate the average episode return from a segmented trajectory.

    Only entries where ``dones`` is set contribute: their returns are summed
    and divided by the number of such entries. The trajectory therefore does
    not have to cover complete episodes.

    Args:
        episode_discount_return: Per-step episode returns, [T, B].
        dones: Done flags, [T, B].
        dp_axis_name: If given, the sum and the count are both reduced with
            ``jax.lax.psum`` over this axis name before dividing.

    Returns:
        The mean return over the done entries, or ``MISSING_REWARD`` when
        the count of done entries is (close to) zero.
    """
    num_done = dones.sum()
    return_total = (episode_discount_return * dones).sum()

    if dp_axis_name is not None:
        return_total = jax.lax.psum(return_total, dp_axis_name)
        num_done = jax.lax.psum(num_done, dp_axis_name)

    no_done_entry = jnp.isclose(num_done, 0)
    placeholder = jnp.full_like(return_total, MISSING_REWARD)

    # The quotient is evaluated even for a zero count; `where` then selects
    # the placeholder in that case.
    return jnp.where(no_done_entry, placeholder, return_total / num_done)
219 220
def approximate_kl(logratio: jax.Array, mode="k3", axis=-1) -> jax.Array:
    """Approximate the KL divergence from sampled log-ratios.

    With ``r = p(x) / q(x)`` and ``x`` sampled from ``q``, the per-sample
    estimators are:

        - ``"k1"``: ``-log r``
        - ``"k2"``: ``0.5 * (log r)^2``
        - ``"k3"``: ``(r - 1) - log r``

    See http://joschu.net/blog/kl-approx.html

    Args:
        logratio: Log-ratio ``log p(x) - log q(x)``, where x are sampled
            from q(x).
        mode: Which estimator to use: ``"k1"``, ``"k2"`` or ``"k3"``.
        axis: Axis (or axes) over which the per-sample estimates are
            averaged.

    Returns:
        Approximated KL(q||p) (Forward KL)

    Raises:
        ValueError: If ``mode`` is not one of ``"k1"``, ``"k2"``, ``"k3"``.
    """
    if mode == "k1":
        per_sample = -logratio
    elif mode == "k2":
        per_sample = 0.5 * jnp.square(logratio)
    elif mode == "k3":
        # (r - 1) - log r. The product (r - 1) * log r would instead average
        # to KL(q||p) + KL(p||q), not the forward KL documented above.
        # expm1 keeps r - 1 accurate when the log-ratio is close to zero.
        per_sample = jnp.expm1(logratio) - logratio
    else:
        raise ValueError(
            f"Unknown mode {mode!r} for approximate_kl; "
            "expected one of 'k1', 'k2', 'k3'."
        )

    return jnp.mean(per_sample, axis=axis)
240 241
def fold_multi_steps(step_fn, num_steps):
    """Chain ``num_steps`` calls of ``step_fn`` into a single function.

    Args:
        step_fn: Callable mapping ``state`` to ``(train_metrics, state)``.
        num_steps: Number of consecutive calls folded into one
            ``jax.lax.scan``.

    Returns:
        A function mapping ``state`` to ``(train_metrics_arr, state)``,
        where ``train_metrics_arr`` holds the metrics of all steps as
        stacked by ``jax.lax.scan`` and ``state`` is the state after the
        last step. The function is wrapped in ``jax.jit`` when
        ``is_jitted(step_fn)`` is true.
    """

    def _scan_body(carry, unused_t):
        # step_fn returns (metrics, state); scan expects (carry, output).
        metrics, next_carry = step_fn(carry)
        return next_carry, metrics

    def _multi_steps(state):
        final_state, stacked_metrics = jax.lax.scan(
            _scan_body, state, (), length=num_steps
        )
        return stacked_metrics, final_state

    return jax.jit(_multi_steps) if is_jitted(step_fn) else _multi_steps