1from functools import partial
2
3import chex
4import jax
5import jax.numpy as jnp
6import jax.tree_util as jtu
7
8from evorl.types import PyTreeNode
9from ..utils import is_layer_norm_layer
10
11
def mlp_mutate(
    x: chex.ArrayTree,
    key: chex.PRNGKey,
    *,
    weight_max_magnitude: float = 10,
    mut_strength: float = 0.01,
    vector_num_mutation_frac: float = 0.0,
    matrix_num_mutation_frac: float = 0.01,
):
    """Apply sparse Gaussian mutation to the parameters of a single MLP.

    For every mutated leaf, ``round(frac * param.size)`` distinct elements are
    chosen uniformly without replacement and perturbed by
    ``jax.random.normal(...) * mut_strength``. The whole leaf is then clipped
    to ``[-weight_max_magnitude, weight_max_magnitude]``.

    Leaves whose path satisfies ``is_layer_norm_layer`` are returned unchanged.
    2-d leaves are mutated with ``matrix_num_mutation_frac``. 1-d leaves are
    mutated with ``vector_num_mutation_frac`` only when it is positive;
    otherwise they are returned unchanged (and unclipped).

    Args:
        x: Parameters of a single individual (a pytree of arrays).
        key: PRNGKey. It is split so that every leaf gets an independent key.
        weight_max_magnitude: Clip bound applied to each mutated leaf.
        mut_strength: Scale applied to the standard-normal noise.
        vector_num_mutation_frac: Fraction of elements mutated in each 1-d
            leaf. Vector mutation is disabled when set to 0.0.
        matrix_num_mutation_frac: Fraction of elements mutated in each 2-d
            leaf.

    Returns:
        A pytree with the same structure as ``x`` holding the mutated
        parameters.

    Raises:
        ValueError: If a non-layer-norm leaf is neither 1-d nor 2-d.
    """
    leaves, treedef = jtu.tree_flatten_with_path(x)
    # One independent key per leaf. Reusing a single key would make every
    # leaf of the same size pick the same indices and receive the same noise.
    leaf_keys = jax.random.split(key, len(leaves))

    def _mutate(param, key, num_mutation_frac):
        # param.size is static, so num_mutations is a Python int and the
        # sampled shapes below are fixed at trace time.
        num_mutations = round(num_mutation_frac * param.size)
        _, ind_key, noise_key = jax.random.split(key, 3)
        # unlike ERL, we sample elements without replacement
        flat_ind = jax.random.choice(
            ind_key, param.size, (num_mutations,), replace=False
        )
        ind = jnp.unravel_index(flat_ind, param.shape)
        updates = jax.random.normal(noise_key, (num_mutations,)) * mut_strength
        # Indices are distinct because they were sampled without replacement.
        param = param.at[ind].add(updates, unique_indices=True)
        return jnp.clip(param, -weight_max_magnitude, weight_max_magnitude)

    params = []
    for i, (path, param) in enumerate(leaves):
        if is_layer_norm_layer(path):
            params.append(param)
            continue

        if param.ndim == 2:  # presumably a dense kernel
            param = _mutate(param, leaf_keys[i], matrix_num_mutation_frac)
        elif param.ndim == 1:  # e.g. a bias (layer-norm leaves skipped above)
            if vector_num_mutation_frac > 0:
                param = _mutate(param, leaf_keys[i], vector_num_mutation_frac)
        else:
            raise ValueError(f"Unsupported parameter shape: {param.shape}")

        params.append(param)

    return jtu.tree_unflatten(treedef, params)
61
62
[docs]
class MLPMutation(PyTreeNode):
    """Population-level wrapper around ``mlp_mutate``.

    ``mlp_mutate`` is ``jax.vmap``-ed over the leading axis of both the
    parameter pytree and the keys, so each individual is mutated with its own
    key.
    """

    # Clip bound applied to each mutated leaf.
    weight_max_magnitude: float = 10
    # Scale applied to the standard-normal mutation noise.
    mut_strength: float = 0.01
    # Fraction of elements mutated in each 1-d leaf (0.0 disables it).
    vector_num_mutation_frac: float = 0.0
    # Fraction of elements mutated in each 2-d leaf.
    matrix_num_mutation_frac: float = 0.01

    def __post_init__(self):
        # Explicit raises instead of `assert`, which is stripped under -O.
        if not 0 <= self.vector_num_mutation_frac <= 1:
            raise ValueError("vector_num_mutation_frac should be in [0, 1]")
        if not 0 <= self.matrix_num_mutation_frac <= 1:
            raise ValueError("matrix_num_mutation_frac should be in [0, 1]")

        self.mutate_fn = jax.vmap(
            partial(
                mlp_mutate,
                weight_max_magnitude=self.weight_max_magnitude,
                mut_strength=self.mut_strength,
                vector_num_mutation_frac=self.vector_num_mutation_frac,
                matrix_num_mutation_frac=self.matrix_num_mutation_frac,
            ),
        )

    def __call__(self, xs: chex.ArrayTree, key: chex.PRNGKey):
        """Mutate a stacked population.

        Args:
            xs: Pytree whose leaves share a leading population axis; the
                population size is read from the first leaf.
            key: Either a single key (``ndim <= 1``), which is split into
                one key per individual, or a batch of keys that must have
                shape ``(pop_size, 2)``.

        Returns:
            The mutated population, with the same structure as ``xs``.
        """
        pop_size = jtu.tree_leaves(xs)[0].shape[0]
        if key.ndim <= 1:
            key = jax.random.split(key, pop_size)
        else:
            chex.assert_shape(
                key,
                (pop_size, 2),
                custom_message=f"Batched key shape {key.shape} must match pop_size: {pop_size}",
            )
        return self.mutate_fn(xs, key)
97 return self.mutate_fn(xs, key)