Source code for evorl.ec.optimizers.ars
1import chex
2import jax
3import jax.numpy as jnp
4import jax.tree_util as jtu
5import optax
6
7from evorl.types import PyTreeData, PyTreeDict, Params, pytree_field
8from evorl.utils.jax_utils import rng_split_like_tree
9
10from .utils import weight_sum, optimizer_map
11from .ec_optimizer import EvoOptimizer, ECState
12
13
21
22
[docs]
class ARS(EvoOptimizer):
    """Augmented Random Search.

    Paper: [Simple random search of static linear policies is competitive for reinforcement learning](https://proceedings.neurips.cc/paper_files/paper/2018/file/7634ea65a4e6d9041cfd3f7de18e334a-Paper.pdf)

    Each generation samples ``pop_size // 2`` Gaussian directions ``z`` and
    evaluates the mirrored candidates ``mean + noise_std * z`` and
    ``mean - noise_std * z``. Only the ``num_elites`` directions with the best
    ``max(r+, r-)`` contribute to the update. Their reward differences are
    normalized by the standard deviation of the elite rewards, and the result
    is fed to the configured optax optimizer.
    """

    # Total number of candidates per generation. Must be a positive even number
    # because candidates are generated as mirrored (+z, -z) pairs.
    pop_size: int
    # Number of top directions (out of pop_size // 2) used in each update.
    num_elites: int
    # Learning rate passed to the optimizer constructor from optimizer_map.
    lr: float
    # Scale of the Gaussian perturbation added to the mean in ask().
    noise_std: float
    # Added to the elite fitness std to avoid division by zero.
    fitness_std_eps: float = 1e-8
    # Key into optimizer_map selecting the optimizer constructor.
    optimizer_name: str = "sgd"

    # Built in __post_init__ from optimizer_name and lr; not a constructor arg.
    optimizer: optax.GradientTransformation = pytree_field(static=True, init=False)

    def __post_init__(self):
        assert self.pop_size > 0 and self.pop_size % 2 == 0, (
            "pop_size must be positive even number"
        )
        # Elites are selected among the pop_size // 2 mirrored directions.
        # A larger value makes lax.top_k raise; zero makes tell() divide by zero.
        assert 0 < self.num_elites <= self.pop_size // 2, (
            "num_elites must be in [1, pop_size // 2]"
        )

        self.optimizer = optimizer_map[self.optimizer_name](learning_rate=self.lr)

    def init(self, mean: Params, key: chex.PRNGKey) -> ARSState:
        """Create the initial optimizer state.

        Args:
            mean: Initial mean of the search distribution (a parameter pytree).
            key: PRNG key consumed by subsequent ``ask`` calls.

        Returns:
            An ``ARSState`` holding ``mean``, a freshly initialized optimizer
            state for ``mean``, and ``key``.
        """
        opt_state = self.optimizer.init(mean)
        return ARSState(mean=mean, opt_state=opt_state, key=key)

    def ask(self, state: ARSState) -> tuple[Params, ARSState]:
        """Sample a mirrored population around the current mean.

        Args:
            state: Current optimizer state.

        Returns:
            ``(pop, new_state)``. Every leaf of ``pop`` gains a leading axis of
            size ``pop_size``. The first half is ``mean + noise_std * z`` and
            the second half is ``mean - noise_std * z`` for the same ``z``.
            ``new_state`` carries the advanced PRNG key and the half noise
            ``z``, which ``tell`` needs to build the update.
        """
        key, sample_key = jax.random.split(state.key)
        # A key tree with the same structure as mean (required by the tree_map below).
        sample_keys = rng_split_like_tree(sample_key, state.mean)

        half_noise = jtu.tree_map(
            lambda x, k: jax.random.normal(k, shape=(self.pop_size // 2, *x.shape)),
            state.mean,
            sample_keys,
        )
        # Antithetic sampling: the second half mirrors the first.
        noise = jtu.tree_map(lambda z: jnp.concatenate([z, -z], axis=0), half_noise)

        pop = jtu.tree_map(
            lambda m, z: m + self.noise_std * z,
            state.mean,
            noise,
        )
        return pop, state.replace(key=key, noise=half_noise)

    def tell(
        self, state: ARSState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, ARSState]:
        """Update the mean from the fitnesses of the population produced by ``ask``.

        Args:
            state: State returned by ``ask``; its ``noise`` holds the half noise.
            fitnesses: Fitness of each candidate, in the same order as the
                population from ``ask`` (positive perturbations first).
                Higher is better.

        Returns:
            ``(metrics, new_state)``. ``metrics`` contains ``fitness_std``.
            ``new_state`` has the updated mean and optimizer state, with its
            ``noise`` cleared.
        """
        half_pop_size = self.pop_size // 2

        fit_p = fitnesses[:half_pop_size]  # r_positive: rewards of mean + noise_std * z
        fit_n = fitnesses[half_pop_size:]  # r_negative: rewards of mean - noise_std * z
        # Rank directions by the better of their two rewards; keep the top ones.
        elite_indices = jax.lax.top_k(jnp.maximum(fit_p, fit_n), self.num_elites)[1]

        fitnesses_elite = jnp.concatenate([fit_p[elite_indices], fit_n[elite_indices]])
        # Std over the 2 * num_elites rewards used in the update.
        # eps keeps the division finite when those rewards are all equal.
        fitness_std = jnp.std(fitnesses_elite) + self.fitness_std_eps

        fit_diff = (fit_p[elite_indices] - fit_n[elite_indices]) / fitness_std

        grad = jtu.tree_map(
            # Negated because the optimizer minimizes while we maximize fitness.
            # weight_sum presumably sums the elite directions weighted by
            # fit_diff over the leading axis (TODO confirm in .utils).
            lambda z: -weight_sum(z[elite_indices], fit_diff) / self.num_elites,
            state.noise,
        )

        # Pass the current mean as `params` so transformations that need it
        # (e.g. weight decay) work. Optimizers that ignore params are unaffected.
        update, opt_state = self.optimizer.update(grad, state.opt_state, state.mean)
        mean = optax.apply_updates(state.mean, update)

        ec_metrics = PyTreeDict(fitness_std=fitness_std)

        return ec_metrics, state.replace(mean=mean, opt_state=opt_state, noise=None)