1from functools import partial
2import chex
3import jax
4import jax.numpy as jnp
5import jax.tree_util as jtu
6
7from evorl.sample_batch import SampleBatch
8from evorl.types import MISSING_REWARD
9
10from .jax_utils import is_jitted, sliding_window
11
12
[docs]
def compute_episode_length(
    dones: chex.Array,  # [T, B]
) -> chex.Array:
    """Count the number of steps in each episode of a batch.

    A step is counted when no done flag was raised on the step right before
    it; the first step has no predecessor and is always counted.

    Args:
        dones: Done flags collected from an episodic trajectory, laid out
            as [T, B].

    Returns:
        Episode lengths with shape [B].
    """
    # The flag of the last step is dropped: a done raised at step t can only
    # affect the steps that come after t.
    not_done_before = 1 - dones[:-1].astype(jnp.int32)  # [T-1, B]

    # "+ 1" accounts for the first step.
    return not_done_before.sum(axis=0) + 1  # [B]
27
28
[docs]
def compute_discount_return(
    rewards: chex.Array,  # [T, B]
    dones: chex.Array,  # [T, B]
    discount: float = 1.0,
) -> chex.Array:
    """Accumulate the discounted return of an episodic trajectory.

    The trajectory is scanned backwards in time with
    ``G_t = r_t + discount * (1 - done_t) * G_{t+1}`` starting from zero,
    and the value reached at the first step is returned.

    Args:
        rewards: Rewards collected from an episodic trajectory, [T, B].
        dones: Dones collected from an episodic trajectory, [T, B].
        discount: Discount factor.

    Returns:
        Discounted return with shape [B].
    """
    # Multiplier applied to the return of the following step. It is zero
    # wherever done is set, so nothing is carried across that step.
    step_discounts = (1 - dones) * discount

    def _backward_step(return_next, step):
        # G_t := r_t + γ * G_{t+1}
        reward_t, step_discount_t = step
        return reward_t + return_next * step_discount_t, None

    # [B]
    initial_return = jnp.zeros_like(rewards[0])

    first_step_return, _ = jax.lax.scan(
        _backward_step,
        initial_return,
        (rewards, step_discounts),
        reverse=True,
        unroll=16,
    )

    return first_step_return  # [B]
64
65
[docs]
def compute_gae(
    rewards: jax.Array,  # [T, B]
    values: jax.Array,  # [T+1, B]
    dones: jax.Array,  # [T, B]
    terminations: jax.Array,  # [T, B]
    gae_lambda: float = 1.0,
    discount: float = 0.99,
) -> tuple[jax.Array, jax.Array]:
    """Calculates the Generalized Advantage Estimation (GAE).

    Args:
        rewards: Rewards of shape [T, B] generated by following the
            behaviour policy.
        values: Value function estimates of shape [T+1, B] wrt. the target
            policy. `values[T]` is the bootstrap value.
        dones: Done signal of shape [T, B]. Where it is set, the advantage
            of the following step is not propagated backwards.
        terminations: Termination signal of shape [T, B]. Where it is set,
            the next-step value is excluded from the TD error.
        gae_lambda: Mix between 1-step (gae_lambda=0) and n-step
            (gae_lambda=1).
        discount: TD discount.

    Returns:
        Tuple:
            - Lambda returns with shape [T, B], can be used as targets to
              train a baseline (V(x_t) - vs_t)^2.
            - Advantages with shape [T, B].
    """
    # `values` must hold exactly one more time step than `rewards`.
    expected_values_shape = (rewards.shape[0] + 1, *rewards.shape[1:])
    chex.assert_shape(values, expected_values_shape)

    next_values = values[1:]
    current_values = values[:-1]

    # One-step TD errors: r_t + γ * (1 - term_t) * V_{t+1} - V_t
    td_errors = rewards + discount * (1 - terminations) * next_values - current_values

    # Weight with which the advantage of step t+1 enters step t.
    carry_factors = discount * gae_lambda * (1 - dones)

    def _accumulate(advantage_next, step):
        td_error_t, carry_factor_t = step
        advantage_t = td_error_t + carry_factor_t * advantage_next
        return advantage_t, advantage_t

    # Backward scan starting from a zero advantage past the last step.
    _, advantages = jax.lax.scan(
        _accumulate,
        jnp.zeros_like(values[0]),
        (td_errors, carry_factors),
        reverse=True,
        unroll=16,
    )

    lambda_returns = advantages + current_values

    return lambda_returns, advantages
116
117
[docs]
def compute_gae_with_horizon(
    rewards: jax.Array,  # [T, B]
    values: jax.Array,  # [T+1, B]
    dones: jax.Array,  # [T, B]
    terminations: jax.Array,  # [T, B]
    gae_horizon: int = 0,
    gae_lambda: float = 1.0,
    discount: float = 0.99,
) -> tuple[jax.Array, jax.Array]:
    """Compute GAE over the whole trajectory or over fixed-size chunks of it.

    With a positive `gae_horizon`, the trajectory is cut into consecutive
    chunks of `gae_horizon` steps and `compute_gae` is applied to every chunk
    independently. Otherwise `compute_gae` is applied to the full trajectory.

    Args:
        rewards: Rewards with shape [T, B].
        values: Value estimates with shape [T+1, B].
        dones: Done signal with shape [T, B].
        terminations: Termination signal with shape [T, B].
        gae_horizon: Chunk length in steps. When positive, T must be a
            multiple of it. Non-positive values disable chunking.
        gae_lambda: Forwarded to `compute_gae`.
        discount: Forwarded to `compute_gae`.

    Returns:
        Tuple of value targets and advantages, both with shape [T, B].
    """
    gae_fn = partial(compute_gae, gae_lambda=gae_lambda, discount=discount)

    if gae_horizon <= 0:
        v_targets, advantages = gae_fn(rewards, values, dones, terminations)
        return v_targets, advantages

    assert rewards.shape[0] % gae_horizon == 0

    def _split_into_chunks(x):
        # [T, B] -> [T//h, h, B]
        return x.reshape(-1, gae_horizon, *x.shape[1:])

    def _merge_chunks(x):
        # [T//h, h, B] -> [T, B]
        return x.reshape(-1, *x.shape[2:])

    chunked_rewards, chunked_dones, chunked_terminations = jtu.tree_map(
        _split_into_chunks, (rewards, dones, terminations)
    )
    # [T+1, B] -> [T//h, h+1, B]
    # NOTE(review): presumably windows of h+1 values taken every h steps, so
    # each chunk also holds its own bootstrap value -- confirm against
    # `sliding_window`, which is defined outside this file.
    chunked_values = sliding_window(values, gae_horizon + 1, gae_horizon)

    # Run GAE on every chunk independently (vmap over the chunk axis).
    v_targets, advantages = jax.vmap(gae_fn)(
        chunked_rewards, chunked_values, chunked_dones, chunked_terminations
    )

    v_targets, advantages = jtu.tree_map(_merge_chunks, (v_targets, advantages))

    return v_targets, advantages
161
162
[docs]
def shuffle_sample_batch(sample_batch: SampleBatch, key: chex.PRNGKey):
    """Randomly permute every leaf of the sample batch.

    The same `key` is handed to `jax.random.permutation` for every leaf.

    Args:
        sample_batch: Pytree of arrays to shuffle.
        key: PRNG key used for the permutation.

    Returns:
        A pytree with the same structure holding the permuted leaves.
    """

    def _permute_leaf(leaf):
        # NOTE(review): reusing one key presumably keeps leaves that share a
        # leading dimension aligned with each other -- confirm with callers.
        return jax.random.permutation(key, leaf)

    return jtu.tree_map(_permute_leaf, sample_batch)
166
167
[docs]
def soft_target_update(target_params, source_params, tau: float):
    """Move the target network parameters towards the source parameters.

    Every leaf is replaced by ``tau * source + (1 - tau) * target``.

    Args:
        target_params: Target network parameters.
        source_params: Source network parameters, with the same tree
            structure as `target_params`.
        tau: Interpolation factor; the weight given to the source.

    Returns:
        Updated target network parameters.
    """

    def _blend(target_leaf, source_leaf):
        return tau * source_leaf + (1 - tau) * target_leaf

    return jtu.tree_map(_blend, target_params, source_params)
184
185
[docs]
def flatten_rollout_trajectory(trajectory: SampleBatch) -> SampleBatch:
    """Merge the two leading axes of every leaf: [T, B, ...] -> [T*B, ...]."""

    def _merge_time_and_batch(leaf):
        # Collapse axes 0 and 1 into a single leading axis.
        return jax.lax.collapse(leaf, 0, 2)

    return jtu.tree_map(_merge_time_and_batch, trajectory)
189
190
[docs]
def flatten_pop_rollout_episode(trajectory: SampleBatch):
    """Fold the population axis into the batch axis of every leaf.

    Each leaf goes from [#pop, T, B, ...] to [T, #pop*B, ...].
    """

    def _fold_pop_into_batch(leaf):
        time_major = leaf.swapaxes(0, 1)  # [T, #pop, B, ...]
        return jax.lax.collapse(time_major, 1, 3)  # [T, #pop*B, ...]

    return jtu.tree_map(_fold_pop_into_batch, trajectory)
194
195
[docs]
def average_episode_discount_return(
    episode_discount_return: jax.Array,  # [T,B]
    dones: jax.Array,  # [T,B]
    dp_axis_name: str | None = None,
) -> jax.Array:
    """Estimate the average episode return from a segmented trajectory.

    Only the entries where `dones` is set contribute: their returns are
    summed and divided by the number of such entries. Complete episodes are
    therefore not required.

    Args:
        episode_discount_return: Per-step episode returns, [T, B].
        dones: Done signal, [T, B].
        dp_axis_name: When given, both the return sum and the done count are
            summed over this named axis with `jax.lax.psum` before dividing.

    Returns:
        The mean return, or `MISSING_REWARD` when the done count is zero.
    """
    num_finished = dones.sum()
    return_total = (episode_discount_return * dones).sum()

    if dp_axis_name is not None:
        # Aggregate across the named axis so every participant gets the
        # same global average.
        return_total = jax.lax.psum(return_total, dp_axis_name)
        num_finished = jax.lax.psum(num_finished, dp_axis_name)

    nothing_finished = jnp.isclose(num_finished, 0)
    missing = jnp.full_like(return_total, MISSING_REWARD)
    mean_return = return_total / num_finished

    # `where` selects the placeholder instead of the 0/0 result.
    return jnp.where(nothing_finished, missing, mean_return)
219
220
[docs]
def approximate_kl(logratio: jax.Array, mode="k3", axis=-1) -> jax.Array:
    """Approximate the KL divergence from sampled log-ratios.

    Implements the k1, k2 and k3 estimators described in
    http://joschu.net/blog/kl-approx.html. With r = p(x)/q(x):

    - "k1": -log r
    - "k2": 0.5 * (log r)^2
    - "k3": (r - 1) - log r (the default; according to the post it is
      unbiased and has low variance)

    Args:
        logratio: log(p(x)/q(x)), where x are sampled from q(x).
        mode: Which estimator to use: "k1", "k2" or "k3".
        axis: Axis over which the per-sample estimates are averaged.

    Returns:
        Approximated KL(q||p) (Forward KL)

    Raises:
        ValueError: If `mode` is not one of "k1", "k2" or "k3".
    """
    if mode == "k1":
        approx_kl = -jnp.mean(logratio, axis=axis)
    elif mode == "k2":
        approx_kl = jnp.mean(0.5 * jnp.square(logratio), axis=axis)
    elif mode == "k3":
        ratio = jnp.exp(logratio)
        # (r - 1) - log r is non-negative for every sample, unlike k1.
        approx_kl = jnp.mean((ratio - 1) - logratio, axis=axis)
    else:
        raise ValueError(
            f"Unknown mode {mode!r}, expected one of 'k1', 'k2', 'k3'."
        )
    return approx_kl
240
241
[docs]
def fold_multi_steps(step_fn, num_steps):
    """Fold multiple steps into a single step function.

    Args:
        step_fn: Function mapping a state to `(train_metrics, state)`.
        num_steps: How many times `step_fn` is applied in a row.

    Returns:
        A function mapping a state to `(train_metrics_arr, state)`, where
        `train_metrics_arr` holds the metrics of all steps as stacked by
        `jax.lax.scan` and `state` is the state after the last step. It is
        wrapped in `jax.jit` when `is_jitted(step_fn)` is true.
    """

    def _scan_body(carry, _unused):
        # `step_fn` returns (metrics, state); scan expects (carry, output).
        metrics, next_carry = step_fn(carry)
        return next_carry, metrics

    def _run_all_steps(state):
        final_state, stacked_metrics = jax.lax.scan(
            _scan_body, state, (), length=num_steps
        )
        return stacked_metrics, final_state

    return jax.jit(_run_all_steps) if is_jitted(step_fn) else _run_all_steps
257 return _multi_steps