1import chex
2import jax
3import jax.numpy as jnp
4import jax.tree_util as jtu
5import optax
6
7from evorl.types import PyTreeData, Params, PyTreeDict, pytree_field
8from evorl.utils.jax_utils import rng_split_like_tree
9
10from .utils import weight_sum, ExponentialScheduleSpec
11from .ec_optimizer import EvoOptimizer
12
13
[docs]
class VanillaESState(PyTreeData):
    """Optimizer state carried between `VanillaES.ask` and `VanillaES.tell`.

    Attributes:
        mean: pytree holding the current search-distribution mean.
        noise_std: standard deviation multiplied onto the sampled noise.
        key: PRNG key; split on every `ask` call.
        noise: scaled perturbations produced by the latest `ask` (one leading
            population axis per leaf). `None` until `ask` is called and reset
            to `None` again by `tell`.
    """

    mean: chex.ArrayTree
    noise_std: chex.Array
    key: chex.PRNGKey
    noise: chex.ArrayTree | None = None
21
22
[docs]
class VanillaES(EvoOptimizer):
    """Canonical Evolution Strategies.

    Each generation samples `pop_size` Gaussian perturbations around the
    current mean, then shifts the mean by a rank-weighted combination of the
    perturbations belonging to the `num_elites` highest-fitness individuals.

    Paper: [Back to basics: Benchmarking canonical evolution strategies for playing atari](https://arxiv.org/abs/1802.08842)

    Attributes:
        pop_size: number of individuals sampled by each `ask` call.
        num_elites: number of top-fitness individuals used by `tell`.
        noise_std_schedule: supplies the `init`, `final` and `decay` values
            that drive the noise standard deviation.
        elite_weights: recombination weights of the elites, derived from
            `num_elites` in `__post_init__` (not a constructor argument).
    """

    pop_size: int
    num_elites: int
    noise_std_schedule: ExponentialScheduleSpec
    elite_weights: chex.Array = pytree_field(init=False)

    def __post_init__(self):
        """Precompute the normalised log-rank weights for the elites."""
        ranks = jnp.arange(1, self.num_elites + 1)
        # w_i is proportional to log(num_elites + 0.5) - log(i), i = 1..num_elites.
        raw_weights = jnp.log(self.num_elites + 0.5) - jnp.log(ranks)
        # Normalise so the weights sum to one.
        self.elite_weights = raw_weights / raw_weights.sum()

    def init(self, mean: Params, key: chex.PRNGKey) -> VanillaESState:
        """Create the initial optimizer state.

        Args:
            mean: initial mean of the search distribution.
            key: PRNG key stored in the state and consumed by `ask`.

        Returns:
            A state whose `noise_std` is the schedule's `init` value (as
            float32) and whose `noise` is left at its default.
        """
        initial_std = jnp.float32(self.noise_std_schedule.init)
        return VanillaESState(mean=mean, noise_std=initial_std, key=key)

    def ask(self, state: VanillaESState) -> tuple[Params, VanillaESState]:
        """Sample a population around the current mean.

        Args:
            state: current optimizer state.

        Returns:
            A tuple `(pop, new_state)`. Every leaf of `pop` has a leading
            axis of size `pop_size`. `new_state` carries the advanced PRNG
            key and the sampled (already scaled) noise for use in `tell`.
        """
        next_key, noise_key = jax.random.split(state.key)
        # One key per leaf of the mean pytree.
        leaf_keys = rng_split_like_tree(noise_key, state.mean)

        def sample_noise(leaf, leaf_key):
            unit_noise = jax.random.normal(
                leaf_key, shape=(self.pop_size, *leaf.shape)
            )
            return unit_noise * state.noise_std

        noise = jtu.tree_map(sample_noise, state.mean, leaf_keys)

        # The mean broadcasts against the leading population axis of the noise.
        pop = jtu.tree_map(
            lambda center, delta: center + delta,
            state.mean,
            noise,
        )
        return pop, state.replace(key=next_key, noise=noise)

    def tell(
        self, state: VanillaESState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, VanillaESState]:
        """Update the mean and the noise std from the population fitnesses.

        Args:
            state: state returned by the preceding `ask` call; its `noise`
                field is indexed with the elite indices.
            fitnesses: fitness of each individual; the `num_elites` largest
                values are selected as elites.

        Returns:
            A tuple of an empty `PyTreeDict` and the updated state, whose
            `noise` is reset to `None`.
        """
        # `top_k` yields indices ordered from best to worst, which lines up
        # with the descending `elite_weights`.
        _, elite_indices = jax.lax.top_k(fitnesses, self.num_elites)

        def shift_mean(center, delta):
            # `weight_sum` is a project helper; presumably it reduces the
            # elite axis using `elite_weights` -- confirm in `.utils`.
            return center + weight_sum(delta[elite_indices], self.elite_weights)

        new_mean = jtu.tree_map(shift_mean, state.mean, state.noise)

        # Blend the current std with the schedule's final value, using a
        # step size of `1 - decay`.
        new_noise_std = optax.incremental_update(
            self.noise_std_schedule.final,
            state.noise_std,
            1 - self.noise_std_schedule.decay,
        )

        return PyTreeDict(), state.replace(
            mean=new_mean, noise_std=new_noise_std, noise=None
        )
83
84
[docs]
class VanillaESMod(VanillaES):
    """Variant of VanillaES.

    Add `external_size` number of external individuals and corresponding
    fitnesses during the ES update by `tell_external()`.

    Attributes:
        external_size: number of external individuals
        mix_strategy: strategy to mix external individuals with the elites.
            - "always": always mix external individuals with elites
            - "normal": concat external individuals to the population and
              select `num_elites` elites from the combined population.
    """

    external_size: int
    mix_strategy: str = "always"

    def __post_init__(self):
        """Build the elite weights and validate the extra configuration.

        Raises:
            ValueError: if `num_elites` is smaller than `external_size`, or
                if `mix_strategy` is neither "always" nor "normal".
        """
        super().__post_init__()
        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, which would let an unknown `mix_strategy` pass
        # construction unnoticed.
        if self.num_elites < self.external_size:
            raise ValueError(
                f"num_elites ({self.num_elites}) must be >= "
                f"external_size ({self.external_size})"
            )
        if self.mix_strategy not in ("always", "normal"):
            raise ValueError(
                "mix_strategy must be 'always' or 'normal', "
                f"got {self.mix_strategy!r}"
            )

    def tell_external(
        self, state: VanillaESState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, VanillaESState]:
        """Run the ES update on the population plus external individuals.

        The caller must place the external individuals after the sampled
        population, both in `fitnesses` and along the leading axis of every
        leaf in `state.noise`.

        Args:
            state: optimizer state whose `noise` leaves have a leading axis
                of size `pop_size + external_size`.
            fitnesses: array of shape `(pop_size + external_size,)`; larger
                values are selected as elites.

        Returns:
            A tuple of an empty `PyTreeDict` and the updated state, whose
            `noise` is reset to `None`.

        Raises:
            ValueError: if `mix_strategy` is neither "always" nor "normal".
        """
        total_size = self.pop_size + self.external_size
        chex.assert_shape(fitnesses, (total_size,))
        chex.assert_tree_shape_prefix(state.noise, (total_size,))

        if self.mix_strategy == "always":
            # Select (num_elites - external_size) elites from the population,
            # then add all external individuals and sort the combined set.
            # NOTE: relies on the external entries being concatenated after
            # the population. TODO: need to improve.
            pop_fitnesses = fitnesses[: self.pop_size]
            external_fitnesses = fitnesses[self.pop_size :]

            pop_elite_fitnesses, pop_elite_indices = jax.lax.top_k(
                pop_fitnesses, self.num_elites - self.external_size
            )

            elite_fitnesses = jnp.concatenate(
                [pop_elite_fitnesses, external_fitnesses]
            )
            # External individuals occupy indices pop_size .. total_size - 1.
            external_indices = jnp.arange(
                self.pop_size,
                self.pop_size + self.external_size,
                dtype=jnp.int32,
            )
            elite_indices = jnp.concatenate([pop_elite_indices, external_indices])

            # Re-rank the combined elites from best to worst so they line up
            # with the descending `elite_weights`.
            elite_indices = elite_indices[
                jnp.argsort(elite_fitnesses, descending=True)
            ]
        elif self.mix_strategy == "normal":
            # External individuals compete with the population on equal terms.
            elite_indices = jax.lax.top_k(fitnesses, self.num_elites)[1]
        else:
            # Guards against a strategy that changed after construction, so
            # it can never silently take the "normal" path.
            raise ValueError(
                "mix_strategy must be 'always' or 'normal', "
                f"got {self.mix_strategy!r}"
            )

        mean = jtu.tree_map(
            lambda x, z: x + weight_sum(z[elite_indices], self.elite_weights),
            state.mean,
            state.noise,
        )

        # Blend the current std with the schedule's final value, using a
        # step size of `1 - decay`.
        noise_std = optax.incremental_update(
            self.noise_std_schedule.final,
            state.noise_std,
            1 - self.noise_std_schedule.decay,
        )

        return PyTreeDict(), state.replace(mean=mean, noise_std=noise_std, noise=None)