Source code for evorl.ec.operators.mutation.mlp_mutation

 1from functools import partial
 2
 3import chex
 4import jax
 5import jax.numpy as jnp
 6import jax.tree_util as jtu
 7
 8from evorl.types import PyTreeNode
 9from ..utils import is_layer_norm_layer
10
11
def mlp_mutate(
    x: chex.ArrayTree,
    key: chex.PRNGKey,
    *,
    weight_max_magnitude: float = 10,
    mut_strength: float = 0.01,
    vector_num_mutation_frac: float = 0.0,
    matrix_num_mutation_frac: float = 0.01,
):
    """Mutate the parameters of a single MLP individual.

    For every mutated parameter, a fixed number of distinct entries is
    selected uniformly at random, Gaussian noise scaled by ``mut_strength``
    is added to them, and the whole parameter is clipped to
    ``[-weight_max_magnitude, weight_max_magnitude]``.

    Each leaf receives its own PRNG key derived from ``key`` and the leaf's
    position in the flattened tree, so identically shaped layers are not
    mutated at the same positions with the same noise.

    Args:
        x: Single individual (a pytree of parameters, no population axis).
        key: PRNGKey for this individual.
        weight_max_magnitude: Absolute bound used to clip every parameter
            that goes through mutation.
        mut_strength: Standard deviation of the additive Gaussian noise.
        vector_num_mutation_frac: Fraction of entries to mutate in each
            vector (1-d) parameter. Vector mutation is disabled when set
            to 0.0; ERL use 0.04.
        matrix_num_mutation_frac: Fraction of entries to mutate in each
            matrix (2-d) parameter.

    Returns:
        A pytree with the same structure as ``x`` holding the mutated
        parameters.

    Raises:
        ValueError: If a leaf that is not skipped by ``is_layer_norm_layer``
            is neither 1-d nor 2-d.
    """
    leaves, treedef = jtu.tree_flatten_with_path(x)

    def _mutate(param, leaf_key, num_mutation_frac):
        # `param.size` is a static Python int, so the number of mutated
        # entries is fixed at trace time.
        num_mutations = round(num_mutation_frac * param.size)

        if num_mutations > 0:
            ind_key, normal_update_key = jax.random.split(leaf_key)
            # unlike ERL, we sample elements without replacement
            flat_ind = jax.random.choice(
                ind_key, param.size, (num_mutations,), replace=False
            )
            ind = jnp.unravel_index(flat_ind, param.shape)
            updates = (
                jax.random.normal(normal_update_key, (num_mutations,))
                * mut_strength
            )
            # Indices are distinct because sampling is without replacement.
            param = param.at[ind].set(param[ind] + updates, unique_indices=True)

        # Clipping is applied even when no entry was selected.
        return jnp.clip(param, -weight_max_magnitude, weight_max_magnitude)

    params = []
    for i, (path, param) in enumerate(leaves):
        # Leaves identified as layer-norm parameters are passed through
        # untouched (neither mutated nor clipped).
        if is_layer_norm_layer(path):
            params.append(param)
            continue

        # Independent key per leaf; reusing `key` directly would give every
        # leaf the same indices and noise.
        leaf_key = jax.random.fold_in(key, i)

        if param.ndim == 2:  # kernel
            param = _mutate(param, leaf_key, matrix_num_mutation_frac)
        elif param.ndim == 1:  # bias (layer-norm leaves were skipped above)
            if vector_num_mutation_frac > 0:
                param = _mutate(param, leaf_key, vector_num_mutation_frac)
        else:
            raise ValueError(f"Unsupported parameter shape: {param.shape}")

        params.append(param)

    return jtu.tree_unflatten(treedef, params)
61
62
class MLPMutation(PyTreeNode):
    """Mutation operator applied to a whole population of MLP individuals.

    On construction the two mutation fractions are checked to lie in
    ``[0, 1]`` and ``mlp_mutate`` is bound to this instance's settings and
    vectorised with ``jax.vmap``. Calling the instance mutates every
    individual along the leading axis of ``xs``.
    """

    # Clipping bound forwarded to `mlp_mutate`.
    weight_max_magnitude: float = 10
    # Noise scale forwarded to `mlp_mutate`.
    mut_strength: float = 0.01
    # Fraction of entries mutated in 1-d parameters; must be in [0, 1].
    vector_num_mutation_frac: float = 0.0
    # Fraction of entries mutated in 2-d parameters; must be in [0, 1].
    matrix_num_mutation_frac: float = 0.01

    def __post_init__(self):
        """Validate the fractions and build the vmapped mutation function."""
        assert 0 <= self.vector_num_mutation_frac <= 1, (
            "vector_num_mutation_frac should be in [0, 1]"
        )
        assert 0 <= self.matrix_num_mutation_frac <= 1, (
            "matrix_num_mutation_frac should be in [0, 1]"
        )

        # Bind the hyper-parameters once; the result takes only (x, key).
        single_mutate = partial(
            mlp_mutate,
            weight_max_magnitude=self.weight_max_magnitude,
            mut_strength=self.mut_strength,
            vector_num_mutation_frac=self.vector_num_mutation_frac,
            matrix_num_mutation_frac=self.matrix_num_mutation_frac,
        )
        self.mutate_fn = jax.vmap(single_mutate)

    def __call__(self, xs: chex.ArrayTree, key: chex.PRNGKey):
        """Mutate every individual in ``xs``.

        Args:
            xs: Pytree whose leaves have the population on their first axis.
            key: Either a single key (``ndim <= 1``), which is split into one
                key per individual, or an already batched key array that must
                have shape ``(pop_size, 2)``.

        Returns:
            The result of the vmapped ``mlp_mutate`` over ``xs`` and the
            per-individual keys.
        """
        # Population size is read from the first leaf's leading dimension.
        first_leaf = jtu.tree_leaves(xs)[0]
        pop_size = first_leaf.shape[0]

        if key.ndim > 1:
            # Caller supplied per-individual keys; only check their shape.
            chex.assert_shape(
                key,
                (pop_size, 2),
                custom_message=f"Batched key shape {key.shape} must match pop_size: {pop_size}",
            )
            per_individual_keys = key
        else:
            per_individual_keys = jax.random.split(key, pop_size)

        return self.mutate_fn(xs, per_individual_keys)