Source code for evorl.ec.optimizers.vanilla_es

  1import chex
  2import jax
  3import jax.numpy as jnp
  4import jax.tree_util as jtu
  5import optax
  6
  7from evorl.types import PyTreeData, Params, PyTreeDict, pytree_field
  8from evorl.utils.jax_utils import rng_split_like_tree
  9
 10from .utils import weight_sum, ExponentialScheduleSpec
 11from .ec_optimizer import EvoOptimizer
 12
 13
class VanillaESState(PyTreeData):
    """State of the VanillaES.

    The optimizer never mutates this object in place; every step returns a
    new instance built with ``replace``.

    Attributes:
        mean: Center of the search distribution; has the same tree structure
            as the parameters passed to ``init``.
        noise_std: Scalar scale applied to the sampled Gaussian perturbations.
        key: PRNG key; ``ask`` splits it and stores the advanced key.
        noise: Perturbations (already multiplied by ``noise_std``) produced by
            the latest ``ask``; ``None`` when nothing is pending.
    """

    # Current search center.
    mean: chex.ArrayTree
    # Current perturbation scale.
    noise_std: chex.Array
    # Random state for the next sampling step.
    key: chex.PRNGKey
    # Pending perturbations awaiting the next update; reset to None afterwards.
    noise: None | chex.ArrayTree = None
class VanillaES(EvoOptimizer):
    """Canonical Evolution Strategies.

    Paper: [Back to basics: Benchmarking canonical evolution strategies for playing atari](https://arxiv.org/abs/1802.08842)

    Each generation samples ``pop_size`` Gaussian perturbations around the
    current mean. It keeps the ``num_elites`` individuals with the highest
    fitness and shifts the mean by a rank-weighted sum of their perturbations.
    The perturbation scale moves exponentially toward
    ``noise_std_schedule.final``.

    Attributes:
        pop_size: Number of perturbations sampled by ``ask``.
        num_elites: Number of top-ranked individuals used in the mean update.
        noise_std_schedule: Provides ``init``, ``final`` and ``decay`` for the
            perturbation scale.
        elite_weights: Derived, not a constructor argument. Equal to
            ``log(num_elites + 0.5) - log(rank)`` for ranks ``1..num_elites``,
            normalized to sum to 1.
    """

    pop_size: int
    num_elites: int
    noise_std_schedule: ExponentialScheduleSpec
    elite_weights: chex.Array = pytree_field(init=False)

    def __post_init__(self):
        # Log-rank weights: the best elite (rank 1) gets the largest weight.
        elite_weights = jnp.log(self.num_elites + 0.5) - jnp.log(
            jnp.arange(1, self.num_elites + 1)
        )
        self.elite_weights = elite_weights / elite_weights.sum()

    def init(self, mean: Params, key: chex.PRNGKey) -> VanillaESState:
        """Create the initial state centered at ``mean``.

        Args:
            mean: Initial center of the search distribution.
            key: PRNG key used for subsequent sampling.

        Returns:
            A state whose ``noise_std`` is ``noise_std_schedule.init`` (as
            float32) and which has no pending noise.
        """
        return VanillaESState(
            mean=mean,
            noise_std=jnp.float32(self.noise_std_schedule.init),
            key=key,
        )

    def ask(self, state: VanillaESState) -> tuple[Params, VanillaESState]:
        """Sample a population around the current mean.

        Args:
            state: Current optimizer state.

        Returns:
            ``(pop, new_state)``. Every leaf of ``pop`` has an extra leading axis
            of size ``pop_size``. ``new_state`` carries the advanced PRNG key and
            the sampled (already scaled) noise that ``tell`` consumes.
        """
        key, sample_key = jax.random.split(state.key)
        # One sampling key per leaf of the mean.
        sample_keys = rng_split_like_tree(sample_key, state.mean)

        # Standard-normal samples scaled by noise_std, one row per individual.
        noise = jtu.tree_map(
            lambda x, k: jax.random.normal(k, shape=(self.pop_size, *x.shape))
            * state.noise_std,
            state.mean,
            sample_keys,
        )

        # The mean broadcasts over the leading population axis.
        pop = jtu.tree_map(
            lambda m, z: m + z,
            state.mean,
            noise,
        )
        return pop, state.replace(key=key, noise=noise)

    def tell(
        self, state: VanillaESState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, VanillaESState]:
        """Update the mean from the fitnesses of the last sampled population.

        Args:
            state: State returned by ``ask``; it must still carry the sampled
                noise.
            fitnesses: 1-D array with one entry per row of ``state.noise``.
                Higher is better, because elites are chosen with ``top_k``.

        Returns:
            An empty ``PyTreeDict`` and the updated state: new mean, updated
            ``noise_std``, and noise cleared to ``None``.

        Raises:
            ValueError: If ``state.noise`` is ``None``, i.e. there was no
                preceding ``ask``.
        """
        if state.noise is None:
            raise ValueError(
                "VanillaES.tell() requires the noise sampled by ask(); "
                "call ask() before each tell()."
            )
        # JAX clamps out-of-range gather indices instead of raising, so a
        # fitness/noise length mismatch would silently corrupt the update.
        chex.assert_rank(fitnesses, 1)
        chex.assert_tree_shape_prefix(state.noise, fitnesses.shape)

        # top_k returns indices sorted from best to worst, matching elite_weights.
        elite_indices = jax.lax.top_k(fitnesses, self.num_elites)[1]

        mean = jtu.tree_map(
            lambda x, z: x + weight_sum(z[elite_indices], self.elite_weights),
            state.mean,
            state.noise,
        )

        # noise_std <- decay * noise_std + (1 - decay) * final
        noise_std = optax.incremental_update(
            self.noise_std_schedule.final,
            state.noise_std,
            1 - self.noise_std_schedule.decay,
        )

        return PyTreeDict(), state.replace(mean=mean, noise_std=noise_std, noise=None)
class VanillaESMod(VanillaES):
    """Variant of VanillaES.

    Adds ``external_size`` external individuals and their fitnesses to the ES
    update through ``tell_external()``.

    Caller contract for ``tell_external``:
      - ``fitnesses`` and every leaf of ``state.noise`` must have
        ``pop_size + external_size`` leading rows.
      - The external entries must be concatenated after the sampled population.
      - Rows of ``state.noise`` are used as offsets from ``state.mean``, because
        the update adds a weighted sum of rows to the mean. External rows
        should therefore hold ``external_individual - mean``.

    Attributes:
        external_size: Number of external individuals.
        mix_strategy: Strategy for mixing external individuals with the elites.
            - "always": always keep every external individual as an elite, and
              fill the remaining ``num_elites - external_size`` slots with the
              best sampled individuals.
            - "normal": concatenate external individuals to the population and
              select ``num_elites`` elites from the combined population.
    """

    external_size: int
    mix_strategy: str = "always"

    def __post_init__(self):
        super().__post_init__()
        assert self.num_elites >= self.external_size, (
            f"num_elites ({self.num_elites}) must be >= "
            f"external_size ({self.external_size})"
        )
        assert self.mix_strategy in ["always", "normal"], (
            f"mix_strategy must be 'always' or 'normal', got {self.mix_strategy!r}"
        )
        # Both strategies need at least num_elites candidates for top_k.
        # For "always" this is num_elites - external_size <= pop_size.
        assert self.num_elites <= self.pop_size + self.external_size, (
            f"num_elites ({self.num_elites}) must be <= pop_size + external_size "
            f"({self.pop_size + self.external_size})"
        )

    def tell_external(
        self, state: VanillaESState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, VanillaESState]:
        """Update the mean using the sampled population plus external individuals.

        Args:
            state: State whose ``noise`` holds ``pop_size + external_size``
                rows, with the external rows placed last.
            fitnesses: Array of shape ``(pop_size + external_size,)`` in the
                same row order as ``state.noise``. Higher is better.

        Returns:
            An empty ``PyTreeDict`` and the updated state: new mean, updated
            ``noise_std``, and noise cleared to ``None``.

        Raises:
            ValueError: If ``state.noise`` is ``None``.
        """
        if state.noise is None:
            raise ValueError(
                "VanillaESMod.tell_external() requires state.noise "
                "(sampled noise concatenated with external offsets)."
            )
        chex.assert_shape(fitnesses, (self.pop_size + self.external_size,))
        chex.assert_tree_shape_prefix(
            state.noise, (self.pop_size + self.external_size,)
        )

        elite_indices = self._select_elite_indices(fitnesses)

        mean = jtu.tree_map(
            lambda x, z: x + weight_sum(z[elite_indices], self.elite_weights),
            state.mean,
            state.noise,
        )

        # noise_std <- decay * noise_std + (1 - decay) * final
        noise_std = optax.incremental_update(
            self.noise_std_schedule.final,
            state.noise_std,
            1 - self.noise_std_schedule.decay,
        )

        return PyTreeDict(), state.replace(mean=mean, noise_std=noise_std, noise=None)

    def _select_elite_indices(self, fitnesses: chex.Array) -> chex.Array:
        """Return the ``num_elites`` elite row indices, ordered best to worst."""
        if self.mix_strategy == "always":
            # Take (num_elites - external_size) elites from the sampled pop,
            # add every external individual, then sort the union by fitness.
            # TODO: need to improve
            pop_fitnesses = fitnesses[: self.pop_size]
            external_fitnesses = fitnesses[self.pop_size :]

            pop_elite_fitnesses, pop_elite_indices = jax.lax.top_k(
                pop_fitnesses, self.num_elites - self.external_size
            )

            elite_fitnesses = jnp.concatenate([pop_elite_fitnesses, external_fitnesses])
            elite_indices = jnp.concatenate(
                [
                    pop_elite_indices,
                    # External rows sit right after the sampled population.
                    jnp.arange(
                        self.pop_size,
                        self.pop_size + self.external_size,
                        dtype=jnp.int32,
                    ),
                ]
            )

            # Order by fitness so elite_weights are applied by rank.
            return elite_indices[jnp.argsort(elite_fitnesses, descending=True)]

        # "normal": plain top-k over the combined population (already sorted).
        return jax.lax.top_k(fitnesses, self.num_elites)[1]