1from functools import partial
2
3import chex
4import jax
5import jax.numpy as jnp
6import jax.tree_util as jtu
7
8from evorl.types import PyTreeNode
9from ..utils import is_layer_norm_layer
10
11
def mlp_crossover(
    x1: chex.ArrayTree,
    x2: chex.ArrayTree,
    key: chex.PRNGKey,
    *,
    num_crossover_frac: float = 0.1,
):
    """Cross over two parameter trees along the leading axis of every leaf.

    Leaves whose path is recognised by ``is_layer_norm_layer`` are passed
    through unchanged. For every other leaf, ``round(shape[0] *
    num_crossover_frac)`` distinct leading-axis indices are sampled. At each
    sampled index a fair coin picks the donor parent, and the donor's slice
    (a row for 2-D leaves, an element for 1-D leaves) is copied into the
    other parent. The donor keeps its own slice, so after the call both
    offspring hold the same values at every sampled index.

    Args:
        x1: Parameter tree of the first parent.
        x2: Parameter tree of the second parent; must have the same
            structure, shapes and dtypes as ``x1``.
        key: PRNG key; it is split once per crossed-over leaf.
        num_crossover_frac: Fraction of leading-axis entries of each leaf
            that take part in the crossover.

    Returns:
        A tuple ``(offspring1, offspring2)`` of trees with the same structure
        as ``x1``.

    Raises:
        ValueError: If a non-layer-norm leaf is not 1-D or 2-D.
    """
    chex.assert_trees_all_equal_shapes_and_dtypes(x1, x2)

    leaves1, treedef = jtu.tree_flatten_with_path(x1)
    leaves2 = jtu.tree_leaves(x2)

    params1 = []
    params2 = []
    for (path, param1), param2 in zip(leaves1, leaves2):
        if is_layer_norm_layer(path):
            # skip layer norm layers
            params1.append(param1)
            params2.append(param2)
            continue

        # Only 1-D and 2-D leaves have a leading axis we know how to cross.
        # 0-d leaves are rejected here rather than failing on ``shape[0]``.
        if param1.ndim not in (1, 2):
            raise ValueError(f"Unsupported parameter shape: {param1.shape}")

        key, ind_key, choice_key = jax.random.split(key, num=3)

        # we use fixed number of crossover op: sampled without replacement
        # this is different from the original ERL: np.random.randint(num_variables * 2)
        num_crossover = round(param1.shape[0] * num_crossover_frac)

        ind = jax.random.choice(
            ind_key, param1.shape[0], (num_crossover,), replace=False
        )

        # mask True  -> x1 is the donor (x1 keeps its slice, x2 receives it)
        # mask False -> x2 is the donor (x2 keeps its slice, x1 receives it)
        mask = jax.random.uniform(choice_key, (num_crossover,)) < 0.5
        if param1.ndim > 1:
            # broadcast the per-row decision over the row's columns
            mask = mask[..., None]

        # Snapshot both parents' slices BEFORE writing anything, so neither
        # update observes the other's result.
        rows1 = param1[ind]
        rows2 = param2[ind]
        donor_rows = jnp.where(mask, rows1, rows2)

        params1.append(param1.at[ind].set(donor_rows, unique_indices=True))
        params2.append(param2.at[ind].set(donor_rows, unique_indices=True))

    return jtu.tree_unflatten(treedef, params1), jtu.tree_unflatten(treedef, params2)
69
70
[docs]
class MLPCrossover(PyTreeNode):
    """Pairwise crossover operator over a batched population of parameters.

    The population is expected as a tree whose leaves share a leading
    population axis. Individuals at even positions are paired with the
    individual at the following odd position, and ``mlp_crossover`` is
    applied to every pair through ``jax.vmap``.

    Attributes:
        num_crossover_frac: Fraction of leading-axis entries of each leaf
            that take part in the crossover; must lie in ``[0, 1]``.
    """

    num_crossover_frac: float = 0.1

    def __post_init__(self):
        """Validate the configuration and build the vmapped crossover."""
        assert 0 <= self.num_crossover_frac <= 1, (
            "num_crossover_frac should be in [0, 1]"
        )

        # NOTE(review): relies on PyTreeNode permitting attribute assignment
        # inside __post_init__ -- confirm against the base class.
        self.crossover_fn = jax.vmap(
            partial(mlp_crossover, num_crossover_frac=self.num_crossover_frac),
        )

    def __call__(self, xs: chex.ArrayTree, key: chex.PRNGKey):
        """Apply crossover to every (even, odd) pair of the population.

        Args:
            xs: Population tree; the first leaf's leading axis defines
                ``pop_size``, which must be even.
            key: Either a single PRNG key (``ndim <= 1``), which is split
                into one key per parent pair, or a batch of keys of shape
                ``(pop_size // 2, 2)`` with one key per pair.

        Returns:
            A tree with the same leading size as ``xs``. Along the leading
            axis, the offspring of all even-positioned parents come first,
            followed by the offspring of all odd-positioned parents.
        """
        pop_size = jtu.tree_leaves(xs)[0].shape[0]
        assert pop_size % 2 == 0, "pop_size must be even"
        num_pairs = pop_size // 2

        parents1 = jtu.tree_map(lambda x: x[0::2], xs)
        parents2 = jtu.tree_map(lambda x: x[1::2], xs)

        if key.ndim <= 1:
            key = jax.random.split(key, num_pairs)
        else:
            # The vmapped crossover runs once per parent pair, so a batched
            # key must supply exactly one key per pair, not per individual.
            chex.assert_shape(
                key,
                (num_pairs, 2),
                custom_message=(
                    f"Batched key shape {key.shape} must match the number of "
                    f"parent pairs: {num_pairs} (pop_size: {pop_size})"
                ),
            )

        offsprings1, offsprings2 = self.crossover_fn(parents1, parents2, key)
        return jtu.tree_map(
            lambda x1, x2: jnp.concatenate([x1, x2], axis=0), offsprings1, offsprings2
        )
102 )