1import chex
2import jax
3import jax.numpy as jnp
4import jax.tree_util as jtu
5import optax
6
7from evorl.types import PyTreeData, PyTreeDict, Params, pytree_field
8from evorl.utils.jax_utils import rng_split_like_tree
9
10from .utils import ExponentialScheduleSpec, weight_sum
11from .ec_optimizer import EvoOptimizer, ECState
12
13
[docs]
class SepCEMState(PyTreeData):
    """Optimizer state carried between ``ask`` and ``tell`` calls of SepCEM.

    SepCEM keeps a diagonal ("separable") Gaussian search distribution, so the
    mean and the variance are stored as pytrees sharing one structure.
    """

    # Center of the search distribution.
    mean: chex.ArrayTree
    # Per-parameter variance; created with the same structure as ``mean``.
    variance: chex.ArrayTree
    # Additive variance noise, annealed on every ``tell``. It is initialised
    # as a float32 scalar by ``SepCEM.init``.
    cov_eps: chex.ArrayTree
    # PRNG key consumed (split) by each ``ask``.
    key: chex.PRNGKey
    # Population sampled by the latest ``ask``; ``None`` before the first
    # ``ask`` and again after every ``tell``.
    pop: chex.ArrayTree | None = None
22
23
[docs]
class SepCEM(EvoOptimizer):
    """Separable Cross-Entropy Method (diagonal-covariance CEM).

    Each generation samples ``pop_size`` candidates from an axis-aligned
    Gaussian (``ask``) and refits the mean and the per-parameter variance to
    the ``num_elites`` best candidates (``tell``). Fitness is maximized.

    Attributes:
        pop_size: Number of candidates sampled per generation.
        num_elites: Number of top candidates used to update the distribution.
        cov_eps_schedule: Schedule for the additive variance noise; its
            ``init``, ``final`` and ``decay`` attributes are read here.
        weighted_update: If True, elites get rank-based logarithmic weights;
            otherwise all elites are weighted uniformly.
        rank_weight_shift: Shift ``s`` in the rank weight
            ``log(num_elites + s) - log(rank)``.
        mirror_sampling: If True, sample ``pop_size // 2`` noise vectors and
            use each one together with its negation.
        elite_weights: Normalized elite weights (summing to 1), computed in
            ``__post_init__``.
    """

    pop_size: int
    num_elites: int  # number of good offspring to update the pop
    cov_eps_schedule: ExponentialScheduleSpec

    weighted_update: bool = True
    rank_weight_shift: float = 1.0
    mirror_sampling: bool = False
    elite_weights: chex.Array = pytree_field(init=False)

    def __post_init__(self):
        """Validate the configuration and precompute the elite weights.

        Raises:
            ValueError: If ``pop_size`` is not positive, if ``num_elites`` is
                outside ``[1, pop_size]``, or if ``mirror_sampling`` is
                enabled with an odd ``pop_size``.
        """
        # Explicit raises instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would let invalid settings through.
        if self.pop_size <= 0:
            raise ValueError("pop_size must be positive")
        if not 0 < self.num_elites <= self.pop_size:
            # num_elites == 0 would yield 0/0 weights; num_elites > pop_size
            # would only fail later inside ``jax.lax.top_k``.
            raise ValueError("num_elites must be in the range [1, pop_size]")
        if self.mirror_sampling and self.pop_size % 2 != 0:
            raise ValueError("pop_size must be even for mirror sampling")

        if self.weighted_update:
            # Rank-based weights: the best elite (rank 1) gets the largest one.
            elite_weights = jnp.log(self.num_elites + self.rank_weight_shift) - jnp.log(
                jnp.arange(1, self.num_elites + 1)
            )
        else:
            elite_weights = jnp.ones((self.num_elites,))

        self.elite_weights = elite_weights / elite_weights.sum()

    def init(self, mean: Params, key: chex.PRNGKey) -> SepCEMState:
        """Create the initial optimizer state.

        Args:
            mean: Initial mean of the search distribution (a pytree of arrays).
            key: PRNG key stored in the state and split on every ``ask``.

        Returns:
            A state whose variance leaves are all ``cov_eps_schedule.init``,
            whose ``cov_eps`` is that same value as a float32 scalar, and
            whose ``pop`` is left at its default of ``None``.
        """
        variance = jtu.tree_map(
            lambda x: jnp.full_like(x, self.cov_eps_schedule.init), mean
        )

        return SepCEMState(
            mean=mean,
            variance=variance,
            cov_eps=jnp.float32(self.cov_eps_schedule.init),
            key=key,
        )

    def ask(self, state: SepCEMState) -> tuple[chex.ArrayTree, ECState]:
        """Sample a new population from the current search distribution.

        Args:
            state: Current optimizer state.

        Returns:
            A tuple ``(pop, state)``. Every leaf of ``pop`` has a leading axis
            of size ``pop_size`` followed by the shape of the matching mean
            leaf. The returned state carries the advanced PRNG key and the
            sampled population.
        """
        key, sample_key = jax.random.split(state.key)
        # One key per leaf of ``mean`` so that leaves receive independent noise.
        sample_keys = rng_split_like_tree(sample_key, state.mean)

        # With mirror sampling only half of the noise is drawn; the other half
        # is its negation.
        num_samples = self.pop_size // 2 if self.mirror_sampling else self.pop_size

        # noise: (#samples, ...), mean: (...)
        noise = jtu.tree_map(
            lambda x, var, k: jax.random.normal(k, (num_samples, *x.shape))
            * jnp.sqrt(var),
            state.mean,
            state.variance,
            sample_keys,
        )

        if self.mirror_sampling:
            noise = jtu.tree_map(
                lambda x: jnp.concatenate([x, -x], axis=0),
                noise,
            )

        # The mean broadcasts against the leading population axis of the noise.
        pop = jtu.tree_map(lambda m, eps: m + eps, state.mean, noise)
        state = state.replace(key=key, pop=pop)

        return pop, state

    def tell(
        self, state: SepCEMState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, SepCEMState]:
        """Update the search distribution from the evaluated population.

        Args:
            state: State returned by the preceding ``ask`` (it must hold the
                sampled population in ``state.pop``).
            fitnesses: Fitness of each individual in ``state.pop``; higher is
                better (e.g. episode return).

        Returns:
            A tuple ``(metrics, state)`` where ``metrics`` is an empty
            ``PyTreeDict`` and ``state`` has the updated mean, variance and
            ``cov_eps``, with ``pop`` reset to ``None``.

        Raises:
            ValueError: If ``state.pop`` is ``None``, i.e. ``tell`` was called
                without a preceding ``ask``.
        """
        if state.pop is None:
            # Without this guard, tree_map would silently map over an empty
            # pytree and fail later with an opaque tree-structure mismatch.
            raise ValueError("state.pop is None; call ask() before tell()")

        # fitness: episode_return, higher is better
        elite_indices = jax.lax.top_k(fitnesses, self.num_elites)[1]

        # New mean: weighted combination of the elites. ``weight_sum`` is a
        # project helper, presumably a weighted sum over the leading (elite)
        # axis.
        mean = jtu.tree_map(
            lambda x: weight_sum(x[elite_indices], self.elite_weights),
            state.pop,
        )

        def var_update(m, x):
            # Squared deviation of every elite from the *old* mean.
            sq_dev = jnp.square(x[elite_indices] - m)
            # TODO: do we need extra division by num_elites mentioned in CEM-RL?
            return weight_sum(sq_dev, self.elite_weights) + state.cov_eps

        variance = jtu.tree_map(
            var_update,
            state.mean,  # old mean
            state.pop,
        )

        # optax.incremental_update(new, old, step) = step * new + (1 - step) * old,
        # i.e. cov_eps <- decay * final + (1 - decay) * cov_eps.
        # NOTE(review): with this argument order ``decay`` is the step toward
        # ``final``, not the retention factor - confirm this matches the
        # intended meaning of ``ExponentialScheduleSpec.decay``.
        cov_eps = optax.incremental_update(
            self.cov_eps_schedule.final, state.cov_eps, self.cov_eps_schedule.decay
        )

        return PyTreeDict(), state.replace(
            mean=mean, variance=variance, cov_eps=cov_eps, pop=None
        )