Source code for evorl.ec.optimizers.openes

  1from collections.abc import Callable
  2
  3import chex
  4import jax
  5import jax.numpy as jnp
  6import jax.tree_util as jtu
  7import optax
  8
  9from evorl.types import PyTreeData, pytree_field, Params, PyTreeDict
 10from evorl.utils.jax_utils import rng_split_like_tree, invert_permutation
 11from evorl.utils.ec_utils import ParamVectorSpec
 12
 13from .utils import ExponentialScheduleSpec, weight_sum, optimizer_map
 14from .ec_optimizer import EvoOptimizer, ECState
 15
 16
[docs] 17def compute_ranks(x): 18 """Get ranks in [0, len(x)-1]. 19 20 This is different from `scipy.stats.rankdata`, which returns ranks in [1, len(x)]. 21 """ 22 assert x.ndim == 1 23 ranks = invert_permutation(jnp.argsort(x)) 24 return ranks
25 26
[docs] 27def compute_centered_ranks(x): 28 """Get centered ranks in [-0.5, 0.5].""" 29 y = compute_ranks(x) 30 y /= x.size - 1 31 y -= 0.5 32 return y
33 34
[docs] 35class OpenESState(PyTreeData): 36 """State of the OpenES.""" 37 38 mean: chex.ArrayTree 39 opt_state: optax.OptState 40 noise_std: chex.Array 41 key: chex.PRNGKey 42 noise: None | chex.ArrayTree = None
43 44
[docs] 45class OpenES(EvoOptimizer): 46 """OpenAI ES.""" 47 48 pop_size: int 49 lr_schedule: ExponentialScheduleSpec 50 noise_std_schedule: ExponentialScheduleSpec 51 mirror_sampling: bool = True 52 optimizer_name: str = "adam" 53 weight_decay: float | None = None 54 55 fitness_shaping_fn: Callable[[chex.Array], chex.Array] = pytree_field( 56 static=True, default=compute_centered_ranks 57 ) 58 optimizer: optax.GradientTransformation = pytree_field(static=True, init=False) 59 60 def __post_init__(self): 61 assert self.pop_size > 0, "pop_size must be positive" 62 if self.mirror_sampling: 63 assert self.pop_size % 2 == 0, "pop_size must be even for mirror sampling" 64 65 self.optimizer = optax.inject_hyperparams(optimizer_map[self.optimizer_name])( 66 learning_rate=self.lr_schedule.init 67 ) 68
[docs] 69 def init(self, mean: Params, key: chex.PRNGKey) -> ECState: 70 return OpenESState( 71 mean=mean, 72 opt_state=self.optimizer.init(mean), 73 noise_std=jnp.float32(self.noise_std_schedule.init), 74 key=key, 75 )
76
[docs] 77 def ask(self, state: ECState) -> tuple[chex.ArrayTree, ECState]: 78 """Generate new candidate solutions.""" 79 key, sample_key = jax.random.split(state.key) 80 sample_keys = rng_split_like_tree(sample_key, state.mean) 81 82 if self.mirror_sampling: 83 noise = jtu.tree_map( 84 lambda x, k: jax.random.normal(k, shape=(self.pop_size // 2, *x.shape)), 85 state.mean, 86 sample_keys, 87 ) 88 noise = jtu.tree_map(lambda z: jnp.concatenate([z, -z], axis=0), noise) 89 else: 90 noise = jtu.tree_map( 91 lambda x, k: jax.random.normal(k, shape=(self.pop_size, *x.shape)), 92 state.mean, 93 sample_keys, 94 ) 95 96 pop = jtu.tree_map( 97 lambda m, z: m + state.noise_std * z, 98 state.mean, 99 noise, 100 ) 101 state = state.replace(key=key, noise=noise) 102 103 return pop, state
104
[docs] 105 def tell( 106 self, state: ECState, fitnesses: chex.Array 107 ) -> tuple[PyTreeDict, OpenESState]: 108 """Update the optimizer state based on the fitnesses of the candidate solutions.""" 109 transformed_fitnesses = self.fitness_shaping_fn(fitnesses) 110 111 # grad = 1/(N*sigma^2) * sum(F_i*(x_i-m)) 112 grad = jtu.tree_map( 113 # Note: we need additional "-1.0" since we are maximizing the fitness 114 lambda z: ( 115 -weight_sum(z, transformed_fitnesses) 116 / (self.pop_size * state.noise_std) 117 ), 118 state.noise, 119 ) 120 121 # add L2 weight decay 122 if self.weight_decay is not None: 123 grad = jtu.tree_map( 124 lambda g, x: g + self.weight_decay * x, 125 grad, 126 state.mean, 127 ) 128 129 update, opt_state = self.optimizer.update(grad, state.opt_state) 130 mean = optax.apply_updates(state.mean, update) 131 132 opt_state.hyperparams["learning_rate"] = optax.incremental_update( 133 self.lr_schedule.final, 134 opt_state.hyperparams["learning_rate"], 135 1 - self.lr_schedule.decay, 136 ) 137 138 noise_std = optax.incremental_update( 139 self.noise_std_schedule.final, 140 state.noise_std, 141 1 - self.noise_std_schedule.decay, 142 ) 143 144 return PyTreeDict(), state.replace( 145 mean=mean, opt_state=opt_state, noise_std=noise_std, noise=None 146 )
147 148
[docs] 149class OpenESNoiseTableState(PyTreeData): 150 """State of the OpenES with noise table.""" 151 152 mean: chex.ArrayTree 153 opt_state: optax.OptState 154 noise_std: chex.Array 155 noise_table: chex.ArrayTree 156 key: chex.PRNGKey 157 noise: None | chex.ArrayTree = None
158 159
[docs] 160class OpenESNoiseTable(EvoOptimizer): 161 """OpenAI ES with noise table.""" 162 163 pop_size: int 164 noise_table_size: int 165 lr_schedule: ExponentialScheduleSpec 166 noise_std_schedule: ExponentialScheduleSpec 167 mirror_sampling: bool = True 168 optimizer_name: str = "adam" 169 weight_decay: float | None = None 170 171 fitness_shaping_fn: Callable[[chex.Array], chex.Array] = pytree_field( 172 static=True, default=compute_centered_ranks 173 ) 174 optimizer: optax.GradientTransformation = pytree_field(static=True, init=False) 175 176 def __post_init__(self): 177 assert self.pop_size > 0, "pop_size must be positive" 178 if self.mirror_sampling: 179 assert self.pop_size % 2 == 0, "pop_size must be even for mirror sampling" 180 181 self.optimizer = optax.inject_hyperparams(optimizer_map[self.optimizer_name])( 182 learning_rate=self.lr_schedule.init 183 ) 184
[docs] 185 def init(self, mean: Params, key: chex.PRNGKey) -> ECState: 186 key, noise_table_key = jax.random.split(key) 187 noise_table = jax.random.normal(noise_table_key, shape=(self.noise_table_size,)) 188 189 return OpenESNoiseTableState( 190 mean=mean, 191 opt_state=self.optimizer.init(mean), 192 noise_std=jnp.float32(self.noise_std_schedule.init), 193 noise_table=noise_table, 194 key=key, 195 )
196
[docs] 197 def ask(self, state: ECState) -> tuple[chex.ArrayTree, ECState]: 198 """Generate new candidate solutions.""" 199 key, sample_key = jax.random.split(state.key) 200 # sample_keys = rng_split_like_tree(sample_key, state.mean) 201 202 param_vec_spec = ParamVectorSpec(state.mean) 203 204 def sample_from_noise_table(idx): 205 return jax.lax.dynamic_slice_in_dim( 206 state.noise_table, idx, param_vec_spec.vec_size, axis=0 207 ) 208 209 if self.mirror_sampling: 210 noise_idx = jax.random.randint( 211 sample_key, 212 shape=(self.pop_size // 2,), 213 minval=0, 214 maxval=self.noise_table_size - param_vec_spec.vec_size, 215 ) 216 noise = param_vec_spec.to_tree(jax.vmap(sample_from_noise_table)(noise_idx)) 217 218 noise = jtu.tree_map(lambda z: jnp.concatenate([z, -z], axis=0), noise) 219 else: 220 noise_idx = jax.random.randint( 221 sample_key, 222 shape=(self.pop_size,), 223 minval=0, 224 maxval=self.noise_table_size - param_vec_spec.vec_size, 225 ) 226 noise = param_vec_spec.to_tree(jax.vmap(sample_from_noise_table)(noise_idx)) 227 228 pop = jtu.tree_map( 229 lambda m, z: m + state.noise_std * z, 230 state.mean, 231 noise, 232 ) 233 state = state.replace(key=key, noise=noise) 234 235 return pop, state
236
[docs] 237 def tell( 238 self, state: ECState, fitnesses: chex.Array 239 ) -> tuple[PyTreeDict, OpenESState]: 240 """Update the optimizer state based on the fitnesses of the candidate solutions.""" 241 transformed_fitnesses = self.fitness_shaping_fn(fitnesses) 242 243 # grad = 1/(N*sigma^2) * sum(F_i*(x_i-m)) 244 grad = jtu.tree_map( 245 # Note: we need additional "-1.0" since we are maximizing the fitness 246 lambda z: ( 247 -weight_sum(z, transformed_fitnesses) 248 / (self.pop_size * state.noise_std) 249 ), 250 state.noise, 251 ) 252 253 # add L2 weight decay 254 if self.weight_decay is not None: 255 grad = jtu.tree_map( 256 lambda g, x: g + self.weight_decay * x, 257 grad, 258 state.mean, 259 ) 260 261 update, opt_state = self.optimizer.update(grad, state.opt_state) 262 mean = optax.apply_updates(state.mean, update) 263 264 opt_state.hyperparams["learning_rate"] = optax.incremental_update( 265 self.lr_schedule.final, 266 opt_state.hyperparams["learning_rate"], 267 1 - self.lr_schedule.decay, 268 ) 269 270 noise_std = optax.incremental_update( 271 self.noise_std_schedule.final, 272 state.noise_std, 273 1 - self.noise_std_schedule.decay, 274 ) 275 276 return PyTreeDict(), state.replace( 277 mean=mean, opt_state=opt_state, noise_std=noise_std, noise=None 278 )