Source code for evorl.ec.optimizers.cem

  1import chex
  2import jax
  3import jax.numpy as jnp
  4import jax.tree_util as jtu
  5import optax
  6
  7from evorl.types import PyTreeData, PyTreeDict, Params, pytree_field
  8from evorl.utils.jax_utils import rng_split_like_tree
  9
 10from .utils import ExponentialScheduleSpec, weight_sum
 11from .ec_optimizer import EvoOptimizer, ECState
 12
 13
class SepCEMState(PyTreeData):
    """Search-distribution state carried between ``SepCEM.ask`` and ``SepCEM.tell``."""

    # Center of the diagonal Gaussian; same pytree structure as the parameters.
    mean: chex.ArrayTree
    # Per-element variance of the Gaussian; same pytree structure as ``mean``.
    variance: chex.ArrayTree
    # Term added to the refit variance on every ``tell`` (a float32 scalar
    # when created by ``SepCEM.init``).
    cov_eps: chex.ArrayTree
    # PRNG key; ``ask`` splits it and stores the advanced key back.
    key: chex.PRNGKey
    # Candidates sampled by the latest ``ask``; reset to ``None`` by ``tell``.
    pop: chex.ArrayTree | None = None
class SepCEM(EvoOptimizer):
    """Sep Cross-Entropy Method (CEM with a diagonal / separable covariance).

    Each generation, ``ask`` samples ``pop_size`` candidates from a diagonal
    Gaussian ``N(mean, diag(variance))`` (independently per leaf of the
    parameter pytree). ``tell`` keeps the ``num_elites`` candidates with the
    highest fitness and refits the mean and variance to them. A scheduled term
    ``cov_eps`` is added to the refit variance every generation.

    Attributes:
        pop_size: Number of candidates sampled by ``ask``. Must be positive,
            and even when ``mirror_sampling`` is enabled.
        num_elites: Number of top-fitness candidates used to refit the
            distribution. Must satisfy ``1 <= num_elites <= pop_size``.
        cov_eps_schedule: Schedule for ``cov_eps``. ``init`` is its starting
            value (and also the initial variance of every parameter). Each
            ``tell`` moves ``cov_eps`` toward ``final`` via
            ``optax.incremental_update`` with ``decay`` as the step size.
        weighted_update: If True, the elite of rank ``r`` (1 = best) gets
            weight ``log(num_elites + rank_weight_shift) - log(r)``.
            Otherwise all elites are weighted equally. In both cases the
            weights are normalized to sum to 1.
        rank_weight_shift: Shift inside the log-rank weights above.
        mirror_sampling: If True, draw ``pop_size // 2`` noise samples and use
            both ``+noise`` and ``-noise`` (mirrored / antithetic sampling).
        elite_weights: Normalized elite weights, computed in ``__post_init__``.
    """

    pop_size: int
    num_elites: int  # number of good offspring to update the pop
    cov_eps_schedule: ExponentialScheduleSpec

    weighted_update: bool = True
    rank_weight_shift: float = 1.0
    mirror_sampling: bool = False
    elite_weights: chex.Array = pytree_field(init=False)

    def __post_init__(self):
        """Validate the configuration and precompute ``elite_weights``."""
        assert self.pop_size > 0, "pop_size must be positive"
        # num_elites == 0 would make the normalization below divide by zero,
        # and num_elites > pop_size would make jax.lax.top_k fail inside tell().
        assert 0 < self.num_elites <= self.pop_size, (
            "num_elites must be in [1, pop_size]"
        )
        if self.mirror_sampling:
            assert self.pop_size % 2 == 0, "pop_size must be even for mirror sampling"

        if self.weighted_update:
            # Log-rank weights: non-increasing in rank, so the best elite
            # gets the largest weight.
            elite_weights = jnp.log(self.num_elites + self.rank_weight_shift) - jnp.log(
                jnp.arange(1, self.num_elites + 1)
            )
        else:
            elite_weights = jnp.ones((self.num_elites,))

        self.elite_weights = elite_weights / elite_weights.sum()

    def init(self, mean: Params, key: chex.PRNGKey) -> SepCEMState:
        """Create the initial search state centered at ``mean``.

        Args:
            mean: Initial parameter pytree; it becomes the distribution mean.
            key: PRNG key used for subsequent sampling.

        Returns:
            A ``SepCEMState`` whose variance leaves are all filled with
            ``cov_eps_schedule.init``, whose ``cov_eps`` starts at the same
            value (as float32), and whose ``pop`` is ``None``.
        """
        variance = jtu.tree_map(
            lambda x: jnp.full_like(x, self.cov_eps_schedule.init), mean
        )

        return SepCEMState(
            mean=mean,
            variance=variance,
            cov_eps=jnp.float32(self.cov_eps_schedule.init),
            key=key,
        )

    def ask(self, state: SepCEMState) -> tuple[chex.ArrayTree, ECState]:
        """Sample a population around the current mean.

        Args:
            state: Current search state.

        Returns:
            ``(pop, new_state)``. ``pop`` has the structure of ``state.mean``
            with a leading axis of size ``pop_size`` on every leaf.
            ``new_state`` carries the advanced PRNG key and stores ``pop`` so
            that ``tell`` can use it.
        """
        key, sample_key = jax.random.split(state.key)
        # One independent key per parameter leaf.
        sample_keys = rng_split_like_tree(sample_key, state.mean)

        # With mirror sampling only half the noise is drawn; the other half is
        # its negation, so the population is symmetric around the mean.
        num_samples = self.pop_size // 2 if self.mirror_sampling else self.pop_size

        noise = jtu.tree_map(
            lambda x, var, k: jax.random.normal(k, (num_samples, *x.shape))
            * jnp.sqrt(var),
            state.mean,
            state.variance,
            sample_keys,
        )

        if self.mirror_sampling:
            noise = jtu.tree_map(
                lambda eps: jnp.concatenate([eps, -eps], axis=0),
                noise,
            )

        # noise: (#pop, ...), mean: (...); broadcasting adds the mean to every row.
        pop = jtu.tree_map(lambda mean, noise: mean + noise, state.mean, noise)
        state = state.replace(key=key, pop=pop)

        return pop, state

    def tell(
        self, state: SepCEMState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, SepCEMState]:
        """Refit the sampling distribution to the elites of the last ``ask``.

        Args:
            state: State returned by the preceding ``ask`` (``state.pop`` set).
            fitnesses: One fitness value per candidate in ``state.pop``;
                higher is better.

        Returns:
            An empty ``PyTreeDict`` of metrics and the updated state. In that
            state ``pop`` is reset to ``None`` and ``cov_eps`` has been
            stepped toward ``cov_eps_schedule.final``.
        """
        assert state.pop is not None, "tell() requires the population sampled by ask()"

        # fitness: episode_return, higher is better.
        # top_k returns indices ordered by descending fitness, which lines up
        # with the non-increasing elite_weights (best elite -> largest weight).
        elite_indices = jax.lax.top_k(fitnesses, self.num_elites)[1]

        # NOTE(review): assumes weight_sum contracts the leading (elite) axis
        # against the weights — confirm in .utils.
        mean = jtu.tree_map(
            lambda x: weight_sum(x[elite_indices], self.elite_weights),
            state.pop,
        )

        def var_update(m, x):
            # Squared deviations are measured from the pre-update mean
            # (state.mean), not from the freshly computed one.
            x_norm = jnp.square(x[elite_indices] - m)
            # TODO: do we need extra division by num_elites mentioned in CEM-RL?
            # The current (pre-decay) cov_eps is added here.
            return weight_sum(x_norm, self.elite_weights) + state.cov_eps

        variance = jtu.tree_map(
            var_update,
            state.mean,  # old mean
            state.pop,
        )

        # optax.incremental_update(new, old, step) = step * new + (1 - step) * old
        cov_eps = optax.incremental_update(
            self.cov_eps_schedule.final, state.cov_eps, self.cov_eps_schedule.decay
        )

        return PyTreeDict(), state.replace(
            mean=mean, variance=variance, cov_eps=cov_eps, pop=None
        )