Source code for evorl.ec.optimizers.utils

 1import chex
 2import jax
 3import jax.numpy as jnp
 4import optax
 5
 6from evorl.types import PyTreeData
 7
 8
 9optimizer_map = dict(
10    adam=optax.adam,
11    sgd=optax.sgd,
12    rmsprop=optax.rmsprop,
13)
14
15
class ExponentialScheduleSpec(PyTreeData):
    """Specification for an exponential schedule for HyperParam.

    This class only stores the schedule parameters as ``float`` fields; it
    defines no schedule logic of its own.
    """

    # NOTE(review): the field meanings below are inferred from their names;
    # confirm against the code that consumes this spec.
    init: float  # presumably the starting value of the hyperparameter
    final: float  # presumably the limiting/floor value the schedule approaches
    decay: float  # presumably the multiplicative decay factor applied per step
22 23
def weight_sum(x: jax.Array, w: jax.Array) -> jax.Array:
    """Weighted sum along the leading axis.

    Computes ``sum_i w[i] * x[i]``.

    Args:
        x: Array of shape ``(n, ...)``.
        w: Weights of shape ``(n,)``.

    Returns:
        Array of shape ``x.shape[1:]``.

    Raises:
        AssertionError: If ``x`` and ``w`` differ in their leading dimension,
            or if ``w`` is not 1-D.
    """
    chex.assert_equal_shape_prefix((x, w), 1)
    # Use chex instead of a bare ``assert`` so the check is not stripped under
    # ``python -O`` and matches the style of the check above.
    chex.assert_rank(w, 1)

    # Append singleton axes so w of shape (n,) broadcasts against x's
    # trailing dimensions.
    w = w.reshape(w.shape + (1,) * (x.ndim - 1))
    return jnp.sum(x * w, axis=0)