1from functools import partial
2
3import chex
4import jax
5import jax.numpy as jnp
6import jax.tree_util as jtu
7
8from evorl.types import PyTreeNode
9from ..utils import is_layer_norm_layer
10
11
def _mutate_selected_entries(
    param: jax.Array,
    keys,
    *,
    num_mutations: int,
    mut_strength: float,
    super_mut_strength: float,
    super_mut_prob: float,
    reset_prob: float,
) -> jax.Array:
    """Return a copy of ``param`` with ``num_mutations`` distinct entries mutated.

    Args:
        param: Parameter array of any rank.
        keys: Four PRNG keys, consumed in this order: index sampling,
            mutation-type sampling, Gaussian perturbation, reset values.
        num_mutations: Static number of entries to mutate.
        mut_strength: Relative noise scale of a normal mutation.
        super_mut_strength: Relative noise scale of a super mutation.
        super_mut_prob: Probability that a selected entry gets a super mutation.
        reset_prob: Probability that a selected entry is re-drawn from N(0, 1).

    Returns:
        Array with the same shape as ``param``.
    """
    ind_key, prob_key, normal_update_key, reset_update_key = keys

    # Sample WITHOUT replacement: the scatter below passes
    # ``unique_indices=True``, which is only valid for distinct indices.
    flat_ind = jax.random.choice(
        ind_key, param.size, (num_mutations,), replace=False
    )
    ind = jnp.unravel_index(flat_ind, param.shape)
    selected = param[ind]

    # One uniform draw per selected entry decides its mutation type:
    #   [0, super_mut_prob)                           -> super mutation
    #   [super_mut_prob, super_mut_prob + reset_prob) -> reset
    #   otherwise                                     -> normal mutation
    prob = jax.random.uniform(prob_key, (num_mutations,))
    super_mask = prob < super_mut_prob
    reset_mask = jnp.logical_and(
        prob >= super_mut_prob, prob < reset_prob + super_mut_prob
    )

    # Perturbations are proportional to the magnitude of the current value.
    updates = jax.random.normal(normal_update_key, (num_mutations,)) * jnp.abs(
        selected
    )
    updates = jnp.where(
        super_mask, updates * super_mut_strength, updates * mut_strength
    )

    reset_param = jax.random.normal(reset_update_key, (num_mutations,))
    return param.at[ind].set(
        jnp.where(reset_mask, reset_param, selected + updates),
        unique_indices=True,
    )


def erl_mutate(
    x: chex.ArrayTree,
    key: chex.PRNGKey,
    *,
    weight_max_magnitude: float = 1e6,
    mut_strength: float = 0.1,
    num_mutation_frac: float = 0.1,
    super_mut_strength: float = 10.0,
    super_mut_prob: float = 0.05,
    reset_prob: float = 0.05,
    vec_relative_prob: float = 0.0,
):
    """Mutation used in the original ERL for MLP.

    Each leaf of ``x`` is handled independently. Leaves whose path is
    recognised by ``is_layer_norm_layer`` are passed through unchanged. For
    the other leaves a candidate is built by mutating a fixed fraction of the
    entries, and the candidate replaces the leaf with a per-leaf random
    probability.

    Args:
        x: single individual,
        key: PRNGKey
        weight_max_magnitude: mutated leaves are clipped to
            ``[-weight_max_magnitude, weight_max_magnitude]``.
        mut_strength: relative noise scale of a normal mutation.
        num_mutation_frac: fraction of entries of a leaf that are mutated.
            Expected to lie in [0, 1] since entries are sampled without
            replacement.
        super_mut_strength: relative noise scale of a super mutation.
        super_mut_prob: probability that a selected entry gets a super
            mutation.
        reset_prob: probability that a selected entry is reset to a fresh
            N(0, 1) sample.
        vec_relative_prob: probability of mutating a vector(1-d) parameter.
            Disable vector mutation when set 0.0; ERL use 0.04

    Returns:
        A pytree with the same structure as ``x`` holding the mutated leaves.

    Raises:
        ValueError: if a non-layer-norm leaf is neither 1-d nor 2-d.
    """
    leaves, treedef = jtu.tree_flatten_with_path(x)
    key, ssne_key = jax.random.split(key)

    # Per-leaf probability threshold of whether a param is mutated at all.
    ssne_probs = jax.random.uniform(ssne_key, (len(leaves),)) * 2

    params = []
    for i, (path, param) in enumerate(leaves):
        if is_layer_norm_layer(path):
            params.append(param)
            continue

        if param.ndim == 2:  # kernel
            accept_scale = 1.0
        elif param.ndim == 1:  # bias
            if vec_relative_prob <= 0:
                # Vector mutation disabled: keep the leaf as is.
                params.append(param)
                continue
            accept_scale = vec_relative_prob
        else:
            raise ValueError(f"Unsupported parameter shape: {param.shape}")

        # Note: We use a fixed number of mutations for a param and sample the
        # mutated elements without replacement. This is a little different
        # from the original ERL.
        num_mutations = round(num_mutation_frac * param.size)

        (
            key,
            ind_key,
            prob_key,
            normal_update_key,
            reset_update_key,
            ssne_prob_key,
        ) = jax.random.split(key, 6)

        new_param = _mutate_selected_entries(
            param,
            (ind_key, prob_key, normal_update_key, reset_update_key),
            num_mutations=num_mutations,
            mut_strength=mut_strength,
            super_mut_strength=super_mut_strength,
            super_mut_prob=super_mut_prob,
            reset_prob=reset_prob,
        )

        # Accept the mutated candidate for the whole leaf, or keep the old one.
        ssne_prob = jax.random.uniform(ssne_prob_key)
        if param.ndim == 2:
            accept = ssne_prob < ssne_probs[i]
        else:
            accept = ssne_prob < ssne_probs[i] * accept_scale
        param = jnp.where(accept, new_param, param)

        param = jnp.clip(param, -weight_max_magnitude, weight_max_magnitude)
        params.append(param)

    return jtu.tree_unflatten(treedef, params)
134
135
[docs]
class ERLMutation(PyTreeNode):
    """Population-level wrapper around ``erl_mutate``.

    The fields mirror the keyword arguments of ``erl_mutate``. After
    construction, ``mutate_fn`` holds a ``jax.vmap``-ed version of
    ``erl_mutate`` with those hyper-parameters bound.
    """

    weight_max_magnitude: float = 1e6
    mut_strength: float = 0.1
    num_mutation_frac: float = 0.1
    super_mut_strength: float = 10.0
    super_mut_prob: float = 0.05
    reset_prob: float = 0.05
    vec_relative_prob: float = 0.0

    def __post_init__(self):
        """Validate ``num_mutation_frac`` and build the vmapped mutation function."""
        assert 0 <= self.num_mutation_frac <= 1, (
            "num_mutation_frac should be in [0, 1]"
        )

        # Hyper-parameters forwarded to ``erl_mutate`` as keyword arguments.
        hyper_params = {
            "weight_max_magnitude": self.weight_max_magnitude,
            "mut_strength": self.mut_strength,
            "num_mutation_frac": self.num_mutation_frac,
            "super_mut_strength": self.super_mut_strength,
            "super_mut_prob": self.super_mut_prob,
            "reset_prob": self.reset_prob,
            "vec_relative_prob": self.vec_relative_prob,
        }
        self.mutate_fn = jax.vmap(partial(erl_mutate, **hyper_params))

    def __call__(self, xs: chex.ArrayTree, key: chex.PRNGKey):
        """Mutate every individual of the batched pytree ``xs``.

        Args:
            xs: pytree whose leaves carry the population on their leading
                axis; the population size is read from the first leaf.
            key: either a single key (``ndim <= 1``), which is split into one
                key per individual, or an already batched key array that must
                have shape ``(pop_size, 2)``.

        Returns:
            The result of the vmapped ``erl_mutate`` applied to ``xs``.
        """
        first_leaf = jtu.tree_leaves(xs)[0]
        pop_size = first_leaf.shape[0]

        if key.ndim > 1:
            # Caller supplied per-individual keys: only check their shape.
            chex.assert_shape(
                key,
                (pop_size, 2),
                custom_message=f"Batched key shape {key.shape} must match pop_size: {pop_size}",
            )
            batched_keys = key
        else:
            batched_keys = jax.random.split(key, pop_size)

        return self.mutate_fn(xs, batched_keys)