1from collections.abc import Callable
2
3import chex
4import jax
5import jax.numpy as jnp
6import jax.tree_util as jtu
7import optax
8
9from evorl.types import PyTreeData, pytree_field, Params, PyTreeDict
10from evorl.utils.jax_utils import rng_split_like_tree, invert_permutation
11from evorl.utils.ec_utils import ParamVectorSpec
12
13from .utils import ExponentialScheduleSpec, weight_sum, optimizer_map
14from .ec_optimizer import EvoOptimizer, ECState
15
16
def compute_ranks(x):
    """Get 0-based ranks of the elements of a 1-D array.

    Ranks lie in ``[0, len(x) - 1]`` and the smallest element gets rank 0.
    This is different from ``scipy.stats.rankdata``, which returns ranks in
    ``[1, len(x)]``.

    ``jnp.argsort`` is stable by default, so tied elements receive distinct
    ranks in index order rather than an averaged rank.

    Args:
        x: 1-D array (or array-like) of values to rank.

    Returns:
        Array with the rank of each element of ``x``, in the same order as
        ``x``.

    Raises:
        ValueError: If ``x`` is not one-dimensional.
    """
    x = jnp.asarray(x)
    # Raise instead of assert so the check is not stripped under ``python -O``.
    # ``ndim`` is static under jit, so this check is trace-safe.
    if x.ndim != 1:
        raise ValueError(f"compute_ranks expects a 1-D array, got shape {x.shape}")
    # argsort yields the element index sitting at each rank; inverting that
    # permutation yields the rank of each element index.
    return invert_permutation(jnp.argsort(x))
25
26
def compute_centered_ranks(x):
    """Get centered ranks in [-0.5, 0.5].

    The ranks from ``compute_ranks`` are rescaled linearly so that the
    smallest element maps to -0.5 and the largest to 0.5.

    Args:
        x: 1-D array of values (e.g. fitnesses) to shape.

    Returns:
        Float array of the same length as ``x``. When ``x`` has fewer than
        two elements every entry is 0.
    """
    ranks = compute_ranks(x)
    n = ranks.size  # static under jit, so the Python branch below is trace-safe
    if n <= 1:
        # The scale factor n - 1 would be 0 and produce 0/0 = NaN, which then
        # propagates into the ES gradient. A lone sample has no relative rank,
        # so give it the neutral centered value 0.
        return jnp.zeros(ranks.shape, dtype=float)
    return ranks / (n - 1) - 0.5
33
34
class OpenESState(PyTreeData):
    """State of the OpenES.

    Attributes:
        mean: Current center of the search distribution (a params pytree).
        opt_state: State of the optax optimizer that updates ``mean``; it also
            carries the injected ``learning_rate`` hyperparameter.
        noise_std: Current perturbation standard deviation (a scalar, created
            with ``jnp.float32`` in ``OpenES.init``).
        key: PRNG key; ``OpenES.ask`` splits it and stores the new half.
        noise: Perturbations drawn by the latest ``ask``. It is reset to
            ``None`` by ``tell``, so it is only set between an ask and a tell.
    """

    mean: chex.ArrayTree
    opt_state: optax.OptState
    noise_std: chex.Array
    key: chex.PRNGKey
    noise: None | chex.ArrayTree = None
43
44
class OpenES(EvoOptimizer):
    """OpenAI ES.

    Each generation samples Gaussian perturbations ``z`` around ``mean``
    (candidates are ``mean + noise_std * z``). It then turns the shaped
    fitnesses into a search-gradient estimate and applies it to ``mean``
    through an optax optimizer. After every ``tell``, the learning rate and
    ``noise_std`` move exponentially toward their schedules' ``final`` values.

    Attributes:
        pop_size: Candidates per generation. Must be positive, and even when
            ``mirror_sampling`` is enabled.
        lr_schedule: Exponential schedule (``init``/``final``/``decay``) for
            the optimizer learning rate.
        noise_std_schedule: Exponential schedule for the perturbation std.
        mirror_sampling: If True, draw ``pop_size // 2`` perturbations and use
            each of them with both signs.
        optimizer_name: Key into ``optimizer_map`` selecting the optax optimizer.
        weight_decay: Optional L2 coefficient added to the gradient estimate.
        fitness_shaping_fn: Maps raw fitnesses to the weights of the gradient
            estimate (centered ranks by default).
        optimizer: optax transformation built in ``__post_init__``.
    """

    pop_size: int
    lr_schedule: ExponentialScheduleSpec
    noise_std_schedule: ExponentialScheduleSpec
    mirror_sampling: bool = True
    optimizer_name: str = "adam"
    weight_decay: float | None = None

    fitness_shaping_fn: Callable[[chex.Array], chex.Array] = pytree_field(
        static=True, default=compute_centered_ranks
    )
    optimizer: optax.GradientTransformation = pytree_field(static=True, init=False)

    def __post_init__(self):
        # Raise rather than assert so validation also runs under ``python -O``.
        if self.pop_size <= 0:
            raise ValueError("pop_size must be positive")
        if self.mirror_sampling and self.pop_size % 2 != 0:
            raise ValueError("pop_size must be even for mirror sampling")

        # inject_hyperparams stores the learning rate in the optimizer state,
        # which lets ``tell`` decay it without rebuilding the optimizer.
        self.optimizer = optax.inject_hyperparams(optimizer_map[self.optimizer_name])(
            learning_rate=self.lr_schedule.init
        )

    def init(self, mean: Params, key: chex.PRNGKey) -> ECState:
        """Create the initial state centered at ``mean``.

        Args:
            mean: Initial parameters (pytree).
            key: PRNG key used for all later sampling.

        Returns:
            An ``OpenESState`` with a fresh optimizer state, ``noise_std`` at
            its schedule's initial value, and no pending noise.
        """
        return OpenESState(
            mean=mean,
            opt_state=self.optimizer.init(mean),
            noise_std=jnp.float32(self.noise_std_schedule.init),
            key=key,
        )

    def ask(self, state: ECState) -> tuple[chex.ArrayTree, ECState]:
        """Generate new candidate solutions.

        Args:
            state: Current optimizer state.

        Returns:
            ``(pop, state)``. Each leaf of ``pop`` has a leading axis of size
            ``pop_size``. The returned state holds the split PRNG key and the
            sampled ``noise``, which ``tell`` needs.
        """
        key, sample_key = jax.random.split(state.key)
        sample_keys = rng_split_like_tree(sample_key, state.mean)

        # With mirror sampling only half the population is drawn; the other
        # half is the negated copy (antithetic pairs).
        num_samples = self.pop_size // 2 if self.mirror_sampling else self.pop_size
        noise = jtu.tree_map(
            lambda x, k: jax.random.normal(k, shape=(num_samples, *x.shape)),
            state.mean,
            sample_keys,
        )
        if self.mirror_sampling:
            noise = jtu.tree_map(lambda z: jnp.concatenate([z, -z], axis=0), noise)

        pop = jtu.tree_map(
            lambda m, z: m + state.noise_std * z,
            state.mean,
            noise,
        )
        return pop, state.replace(key=key, noise=noise)

    def tell(
        self, state: ECState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, OpenESState]:
        """Update the optimizer state based on the fitnesses of the candidate solutions.

        Args:
            state: The state returned by ``ask`` (it must still hold ``noise``).
            fitnesses: Fitness of each candidate, presumably in the order
                ``ask`` produced them; higher is better.

        Returns:
            ``(metrics, state)``. ``metrics`` is an empty ``PyTreeDict``. The
            new state has the updated mean, the decayed learning rate and
            ``noise_std``, and ``noise`` cleared.

        Raises:
            ValueError: If ``state`` carries no noise, i.e. it did not come
                from ``ask``.
        """
        # ``noise is None`` is a pytree-structure property, so this check is
        # static and safe under jit.
        if state.noise is None:
            raise ValueError("tell() requires the state returned by ask()")

        transformed_fitnesses = self.fitness_shaping_fn(fitnesses)

        # grad = 1/(N*sigma^2) * sum(F_i*(x_i-m)) = 1/(N*sigma) * sum(F_i*z_i),
        # since x_i - m = sigma * z_i. The leading minus sign turns fitness
        # maximization into the minimization that optax performs.
        grad = jtu.tree_map(
            lambda z: (
                -weight_sum(z, transformed_fitnesses)
                / (self.pop_size * state.noise_std)
            ),
            state.noise,
        )

        # L2 weight decay, folded into the gradient before the optimizer.
        if self.weight_decay is not None:
            grad = jtu.tree_map(
                lambda g, x: g + self.weight_decay * x,
                grad,
                state.mean,
            )

        update, opt_state = self.optimizer.update(grad, state.opt_state)
        mean = optax.apply_updates(state.mean, update)

        # Exponential decay toward the final value:
        # lr <- decay * lr + (1 - decay) * final.
        # The hyperparams dict is replaced rather than mutated in place.
        learning_rate = optax.incremental_update(
            self.lr_schedule.final,
            opt_state.hyperparams["learning_rate"],
            1 - self.lr_schedule.decay,
        )
        opt_state = opt_state._replace(
            hyperparams={**opt_state.hyperparams, "learning_rate": learning_rate}
        )

        noise_std = optax.incremental_update(
            self.noise_std_schedule.final,
            state.noise_std,
            1 - self.noise_std_schedule.decay,
        )

        return PyTreeDict(), state.replace(
            mean=mean, opt_state=opt_state, noise_std=noise_std, noise=None
        )
147
148
class OpenESNoiseTableState(PyTreeData):
    """State of the OpenES with noise table.

    Attributes:
        mean: Current center of the search distribution (a params pytree).
        opt_state: State of the optax optimizer that updates ``mean``; it also
            carries the injected ``learning_rate`` hyperparameter.
        noise_std: Current perturbation standard deviation (a scalar, created
            with ``jnp.float32`` in ``OpenESNoiseTable.init``).
        noise_table: 1-D array of standard-normal samples of length
            ``noise_table_size``. Perturbations are contiguous slices of it.
        key: PRNG key; ``OpenESNoiseTable.ask`` splits it and stores the new
            half.
        noise: Perturbations drawn by the latest ``ask``. It is reset to
            ``None`` by ``tell``, so it is only set between an ask and a tell.
    """

    mean: chex.ArrayTree
    opt_state: optax.OptState
    noise_std: chex.Array
    noise_table: chex.ArrayTree
    key: chex.PRNGKey
    noise: None | chex.ArrayTree = None
158
159
class OpenESNoiseTable(EvoOptimizer):
    """OpenAI ES with noise table.

    Like ``OpenES``, except that perturbations are not drawn fresh for every
    parameter. Instead, each candidate's perturbation is a contiguous slice
    (of the flattened parameter length) of a fixed standard-normal noise
    table created in ``init``. Only the slice start indices are random.

    Attributes:
        pop_size: Candidates per generation. Must be positive, and even when
            ``mirror_sampling`` is enabled.
        noise_table_size: Length of the shared noise table. It must be at
            least the flattened parameter length (checked in ``init``).
        lr_schedule: Exponential schedule (``init``/``final``/``decay``) for
            the optimizer learning rate.
        noise_std_schedule: Exponential schedule for the perturbation std.
        mirror_sampling: If True, draw ``pop_size // 2`` perturbations and use
            each of them with both signs.
        optimizer_name: Key into ``optimizer_map`` selecting the optax optimizer.
        weight_decay: Optional L2 coefficient added to the gradient estimate.
        fitness_shaping_fn: Maps raw fitnesses to the weights of the gradient
            estimate (centered ranks by default).
        optimizer: optax transformation built in ``__post_init__``.
    """

    pop_size: int
    noise_table_size: int
    lr_schedule: ExponentialScheduleSpec
    noise_std_schedule: ExponentialScheduleSpec
    mirror_sampling: bool = True
    optimizer_name: str = "adam"
    weight_decay: float | None = None

    fitness_shaping_fn: Callable[[chex.Array], chex.Array] = pytree_field(
        static=True, default=compute_centered_ranks
    )
    optimizer: optax.GradientTransformation = pytree_field(static=True, init=False)

    def __post_init__(self):
        # Raise rather than assert so validation also runs under ``python -O``.
        if self.pop_size <= 0:
            raise ValueError("pop_size must be positive")
        if self.mirror_sampling and self.pop_size % 2 != 0:
            raise ValueError("pop_size must be even for mirror sampling")

        # inject_hyperparams stores the learning rate in the optimizer state,
        # which lets ``tell`` decay it without rebuilding the optimizer.
        self.optimizer = optax.inject_hyperparams(optimizer_map[self.optimizer_name])(
            learning_rate=self.lr_schedule.init
        )

    def init(self, mean: Params, key: chex.PRNGKey) -> ECState:
        """Create the initial state centered at ``mean`` and build the noise table.

        Args:
            mean: Initial parameters (pytree).
            key: PRNG key. One split builds the noise table; the rest is kept
                for sampling.

        Returns:
            An ``OpenESNoiseTableState``.

        Raises:
            ValueError: If the noise table is shorter than the flattened
                parameter vector, so no perturbation slice would fit.
        """
        vec_size = ParamVectorSpec(mean).vec_size
        if vec_size > self.noise_table_size:
            raise ValueError(
                f"noise_table_size ({self.noise_table_size}) must be at least "
                f"the flattened parameter size ({vec_size})"
            )

        key, noise_table_key = jax.random.split(key)
        noise_table = jax.random.normal(noise_table_key, shape=(self.noise_table_size,))

        return OpenESNoiseTableState(
            mean=mean,
            opt_state=self.optimizer.init(mean),
            noise_std=jnp.float32(self.noise_std_schedule.init),
            noise_table=noise_table,
            key=key,
        )

    def ask(self, state: ECState) -> tuple[chex.ArrayTree, ECState]:
        """Generate new candidate solutions.

        Args:
            state: Current optimizer state.

        Returns:
            ``(pop, state)``. Each leaf of ``pop`` has a leading axis of size
            ``pop_size``. The returned state holds the split PRNG key and the
            sampled ``noise``, which ``tell`` needs.
        """
        key, sample_key = jax.random.split(state.key)

        param_vec_spec = ParamVectorSpec(state.mean)
        vec_size = param_vec_spec.vec_size

        def sample_from_noise_table(idx):
            # One perturbation = the contiguous slice [idx, idx + vec_size).
            return jax.lax.dynamic_slice_in_dim(
                state.noise_table, idx, vec_size, axis=0
            )

        # With mirror sampling only half the population is drawn; the other
        # half is the negated copy (antithetic pairs).
        num_samples = self.pop_size // 2 if self.mirror_sampling else self.pop_size
        # randint's maxval is exclusive. The last valid start index is
        # noise_table_size - vec_size, so the bound is that value + 1.
        noise_idx = jax.random.randint(
            sample_key,
            shape=(num_samples,),
            minval=0,
            maxval=self.noise_table_size - vec_size + 1,
        )
        noise = param_vec_spec.to_tree(jax.vmap(sample_from_noise_table)(noise_idx))
        if self.mirror_sampling:
            noise = jtu.tree_map(lambda z: jnp.concatenate([z, -z], axis=0), noise)

        pop = jtu.tree_map(
            lambda m, z: m + state.noise_std * z,
            state.mean,
            noise,
        )
        return pop, state.replace(key=key, noise=noise)

    def tell(
        self, state: ECState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, OpenESNoiseTableState]:
        """Update the optimizer state based on the fitnesses of the candidate solutions.

        Args:
            state: The state returned by ``ask`` (it must still hold ``noise``).
            fitnesses: Fitness of each candidate, presumably in the order
                ``ask`` produced them; higher is better.

        Returns:
            ``(metrics, state)``. ``metrics`` is an empty ``PyTreeDict``. The
            new state has the updated mean, the decayed learning rate and
            ``noise_std``, and ``noise`` cleared.

        Raises:
            ValueError: If ``state`` carries no noise, i.e. it did not come
                from ``ask``.
        """
        # ``noise is None`` is a pytree-structure property, so this check is
        # static and safe under jit.
        if state.noise is None:
            raise ValueError("tell() requires the state returned by ask()")

        transformed_fitnesses = self.fitness_shaping_fn(fitnesses)

        # grad = 1/(N*sigma^2) * sum(F_i*(x_i-m)) = 1/(N*sigma) * sum(F_i*z_i),
        # since x_i - m = sigma * z_i. The leading minus sign turns fitness
        # maximization into the minimization that optax performs.
        grad = jtu.tree_map(
            lambda z: (
                -weight_sum(z, transformed_fitnesses)
                / (self.pop_size * state.noise_std)
            ),
            state.noise,
        )

        # L2 weight decay, folded into the gradient before the optimizer.
        if self.weight_decay is not None:
            grad = jtu.tree_map(
                lambda g, x: g + self.weight_decay * x,
                grad,
                state.mean,
            )

        update, opt_state = self.optimizer.update(grad, state.opt_state)
        mean = optax.apply_updates(state.mean, update)

        # Exponential decay toward the final value:
        # lr <- decay * lr + (1 - decay) * final.
        # The hyperparams dict is replaced rather than mutated in place.
        learning_rate = optax.incremental_update(
            self.lr_schedule.final,
            opt_state.hyperparams["learning_rate"],
            1 - self.lr_schedule.decay,
        )
        opt_state = opt_state._replace(
            hyperparams={**opt_state.hyperparams, "learning_rate": learning_rate}
        )

        noise_std = optax.incremental_update(
            self.noise_std_schedule.final,
            state.noise_std,
            1 - self.noise_std_schedule.decay,
        )

        return PyTreeDict(), state.replace(
            mean=mean, opt_state=opt_state, noise_std=noise_std, noise=None
        )