Source code for evorl.ec.optimizers.utils
1import chex
2import jax
3import jax.numpy as jnp
4import optax
5
6from evorl.types import PyTreeData
7
8
# Lookup table from a short optimizer name to the corresponding optax
# constructor (e.g. ``optax.adam``); callers select an optimizer by string key.
optimizer_map = {
    "adam": optax.adam,
    "sgd": optax.sgd,
    "rmsprop": optax.rmsprop,
}
14
15
[docs]
class ExponentialScheduleSpec(PyTreeData):
    """Specification for an exponential schedule for HyperParam.

    A plain data container (a subclass of evorl's ``PyTreeData``) holding the
    three scalars that parameterize the schedule. This class defines no
    methods; how the fields are combined is decided by the code that consumes
    the spec.
    """

    # NOTE(review): presumably the starting value of the hyperparameter —
    # confirm against the schedule implementation.
    init: float
    # NOTE(review): presumably the value the schedule decays toward or is
    # clamped at — TODO confirm.
    final: float
    # NOTE(review): presumably the per-step multiplicative decay factor —
    # TODO confirm.
    decay: float
21 decay: float
22
23
[docs]
def weight_sum(x: jax.Array, w: jax.Array) -> jax.Array:
    """Weighted sum of ``x`` along its leading axis.

    Computes ``sum_i w[i] * x[i]``, broadcasting each scalar weight over the
    trailing dimensions of ``x``.

    Args:
        x: Array of shape ``(n, ...)``.
        w: 1-D array of shape ``(n,)``, one weight per leading slice of ``x``.

    Returns:
        Array of shape ``x.shape[1:]``.

    Raises:
        AssertionError: if ``w`` is not 1-D, or if ``x`` and ``w`` disagree on
            their leading dimension.
    """
    # Use chex instead of a bare ``assert``: a bare assert is removed under
    # ``python -O``, after which a non-1-D ``w`` would be silently reshaped and
    # broadcast into a wrong result instead of being rejected.
    chex.assert_rank(w, 1)
    chex.assert_equal_shape_prefix((x, w), 1)

    # (n,) -> (n, 1, ..., 1) so each weight scales one whole slice x[i].
    w = w.reshape(w.shape + (1,) * (x.ndim - 1))
    return jnp.sum(x * w, axis=0)