Source code for evorl.ec.optimizers.ars

 1import chex
 2import jax
 3import jax.numpy as jnp
 4import jax.tree_util as jtu
 5import optax
 6
 7from evorl.types import PyTreeData, PyTreeDict, Params, pytree_field
 8from evorl.utils.jax_utils import rng_split_like_tree
 9
10from .utils import weight_sum, optimizer_map
11from .ec_optimizer import EvoOptimizer, ECState
12
13
class ARSState(PyTreeData):
    """Container for everything ARS carries from one generation to the next.

    Attributes:
        mean: Pytree holding the current center of the search distribution.
        opt_state: State of the optax optimizer that updates ``mean``.
        key: PRNG key consumed (split) when a new population is sampled.
        noise: Perturbations of the first half of the most recently sampled
            population, or ``None`` when no population is pending evaluation.
    """

    mean: chex.ArrayTree
    opt_state: optax.OptState
    key: chex.PRNGKey
    noise: chex.ArrayTree | None = None
21 22
class ARS(EvoOptimizer):
    """Augmented Random Search.

    Paper: [Simple random search of static linear policies is competitive for reinforcement learning](https://proceedings.neurips.cc/paper_files/paper/2018/file/7634ea65a4e6d9041cfd3f7de18e334a-Paper.pdf)

    Every generation draws ``pop_size // 2`` Gaussian perturbations and
    evaluates each of them mirrored (``mean + noise_std * z`` and
    ``mean - noise_std * z``). The ``num_elites`` perturbation directions
    with the best fitness on either side are used to build a search
    gradient, which is then applied to ``mean`` by an optax optimizer.

    Attributes:
        pop_size: Number of individuals per generation; must be a positive
            even number.
        num_elites: Number of perturbation directions used for the update;
            must lie in ``[1, pop_size // 2]``.
        lr: Learning rate handed to the optimizer.
        noise_std: Scale applied to the sampled perturbations.
        fitness_std_eps: Constant added to the elite fitness standard
            deviation before dividing by it.
        optimizer_name: Key into ``optimizer_map`` selecting the optimizer.
        optimizer: The optimizer instance, created in ``__post_init__``.
    """

    pop_size: int
    num_elites: int
    lr: float
    noise_std: float
    fitness_std_eps: float = 1e-8
    optimizer_name: str = "sgd"

    optimizer: optax.GradientTransformation = pytree_field(static=True, init=False)

    def __post_init__(self):
        """Validate the hyper-parameters and build the optimizer."""
        assert self.pop_size > 0 and self.pop_size % 2 == 0, (
            "pop_size must be positive even number"
        )
        # tell() selects num_elites directions out of pop_size // 2 and divides
        # by num_elites, so values outside this range cannot work.
        assert 0 < self.num_elites <= self.pop_size // 2, (
            "num_elites must be in [1, pop_size // 2]"
        )

        self.optimizer = optimizer_map[self.optimizer_name](learning_rate=self.lr)

    def init(self, mean: Params, key: chex.PRNGKey) -> ARSState:
        """Create the initial state.

        Args:
            mean: Initial center of the search distribution.
            key: PRNG key used for all subsequent sampling.

        Returns:
            A fresh ``ARSState`` with no pending noise.
        """
        opt_state = self.optimizer.init(mean)
        return ARSState(mean=mean, opt_state=opt_state, key=key)

    def ask(self, state: ARSState) -> tuple[Params, ARSState]:
        """Sample a mirrored population around the current mean.

        Args:
            state: Current optimizer state.

        Returns:
            A tuple ``(pop, new_state)``. Every leaf of ``pop`` has a leading
            axis of size ``pop_size``: the first half holds the positively
            perturbed individuals and the second half their mirrored
            counterparts. ``new_state`` carries the advanced PRNG key and the
            perturbations of the first half, which ``tell`` needs.
        """
        key, sample_key = jax.random.split(state.key)
        sample_keys = rng_split_like_tree(sample_key, state.mean)

        half_noise = jtu.tree_map(
            lambda x, k: jax.random.normal(k, shape=(self.pop_size // 2, *x.shape)),
            state.mean,
            sample_keys,
        )
        # Mirrored sampling: individual i + pop_size // 2 uses the negated
        # perturbation of individual i.
        noise = jtu.tree_map(lambda z: jnp.concatenate([z, -z], axis=0), half_noise)

        pop = jtu.tree_map(
            lambda m, z: m + self.noise_std * z,
            state.mean,
            noise,
        )
        # Only the positive half is stored; the other half is its negation.
        return pop, state.replace(key=key, noise=half_noise)

    def tell(
        self, state: ARSState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, ARSState]:
        """Update the mean from the fitnesses of the last sampled population.

        Args:
            state: The state returned by the matching ``ask`` call.
            fitnesses: Fitness of every individual, ordered like the
                population returned by ``ask`` (higher is better).

        Returns:
            A tuple ``(ec_metrics, new_state)`` where ``ec_metrics`` contains
            ``fitness_std`` (the stabilised standard deviation of the elite
            fitnesses) and ``new_state`` holds the updated mean and optimizer
            state with ``noise`` reset to ``None``.

        Raises:
            ValueError: If ``state`` carries no noise, i.e. it does not come
                from ``ask``.
        """
        if state.noise is None:
            raise ValueError(
                "state.noise is None; tell() must be called with the state "
                "returned by ask()"
            )

        half_pop_size = self.pop_size // 2

        fit_p = fitnesses[:half_pop_size]  # r_positive
        fit_n = fitnesses[half_pop_size:]  # r_negative
        # A direction is ranked by the better of its two mirrored evaluations;
        # top_k returns (values, indices) and only the indices are needed.
        elite_indices = jax.lax.top_k(jnp.maximum(fit_p, fit_n), self.num_elites)[1]

        fitnesses_elite = jnp.concatenate([fit_p[elite_indices], fit_n[elite_indices]])
        # Add a small constant so the division below stays finite when all
        # elite fitnesses are identical.
        fitness_std = jnp.std(fitnesses_elite) + self.fitness_std_eps

        fit_diff = (fit_p[elite_indices] - fit_n[elite_indices]) / fitness_std

        grad = jtu.tree_map(
            # Note: we need additional "-1.0" since we are maximizing the fitness
            # NOTE(review): weight_sum presumably sums the elite noise along its
            # leading axis weighted by fit_diff - confirm against .utils.
            lambda z: (-weight_sum(z[elite_indices], fit_diff) / (self.num_elites)),
            state.noise,
        )

        # Pass the current parameters so that optax transformations which need
        # them (e.g. weight decay) work; plain sgd/adam ignore the argument.
        update, opt_state = self.optimizer.update(grad, state.opt_state, state.mean)
        mean = optax.apply_updates(state.mean, update)

        ec_metrics = PyTreeDict(fitness_std=fitness_std)

        return ec_metrics, state.replace(mean=mean, opt_state=opt_state, noise=None)