Source code for evorl.ec.operators.mutation.erl_mutation

  1from functools import partial
  2
  3import chex
  4import jax
  5import jax.numpy as jnp
  6import jax.tree_util as jtu
  7
  8from evorl.types import PyTreeNode
  9from ..utils import is_layer_norm_layer
 10
 11
 12def erl_mutate(
 13    x: chex.ArrayTree,
 14    key: chex.PRNGKey,
 15    *,
 16    weight_max_magnitude: float = 1e6,
 17    mut_strength: float = 0.1,
 18    num_mutation_frac: float = 0.1,
 19    super_mut_strength: float = 10.0,
 20    super_mut_prob: float = 0.05,
 21    reset_prob: float = 0.05,
 22    vec_relative_prob: float = 0.0,
 23):
 24    """Mutation used in the original ERL for MLP.
 25
 26    Args:
 27        key: PRNGKey
 28        x: single individual,
 29        vec_relative_prob: probability of mutating a vector(1-d) parameter.
 30            Disable vector mutation when set 0.0; ERL use 0.04
 31    """
 32    leaves, treedef = jtu.tree_flatten_with_path(x)
 33    key, ssne_key = jax.random.split(key)
 34
 35    # prob thresould of whether mutate a param
 36    ssne_probs = jax.random.uniform(ssne_key, (len(leaves),)) * 2
 37
 38    params = []
 39    for i, (path, param) in enumerate(leaves):
 40        if is_layer_norm_layer(path):
 41            params.append(param)
 42            continue
 43
 44        if param.ndim == 2:  # kernel
 45            # Note: We use fixed number of mutations for a param,
 46            # This is a little different from the original ERL
 47            num_mutations = round(num_mutation_frac * param.size)
 48
 49            (
 50                key,
 51                ind_key,
 52                prob_key,
 53                normal_update_key,
 54                reset_update_key,
 55                ssne_prob_key,
 56            ) = jax.random.split(key, 6)
 57
 58            # unlike ERL, we sample elements without replacement
 59            num_param = param.shape[0] * param.shape[1]
 60            flat_ind = jax.random.choice(
 61                ind_key, num_param, (num_mutations,), replace=False
 62            )
 63            ind = jnp.unravel_index(flat_ind, param.shape)
 64
 65            prob = jax.random.uniform(prob_key, (num_mutations,))
 66            super_mask = prob < super_mut_prob
 67            reset_mask = jnp.logical_and(
 68                prob >= super_mut_prob, prob < reset_prob + super_mut_prob
 69            )
 70
 71            updates = jax.random.normal(normal_update_key, (num_mutations,)) * jnp.abs(
 72                param[ind]
 73            )
 74            updates = jnp.where(
 75                super_mask, updates * super_mut_strength, updates * mut_strength
 76            )
 77
 78            reset_param = jax.random.normal(reset_update_key, (num_mutations,))
 79            new_param = param.at[ind].set(
 80                jnp.where(reset_mask, reset_param, param[ind] + updates),
 81                unique_indices=True,
 82            )
 83
 84            ssne_prob = jax.random.uniform(ssne_prob_key)
 85            param = jnp.where(ssne_prob < ssne_probs[i], new_param, param)
 86
 87            param = jnp.clip(param, -weight_max_magnitude, weight_max_magnitude)
 88
 89        elif param.ndim == 1:  # bias or layer norm
 90            if vec_relative_prob > 0:
 91                num_mutations = round(num_mutation_frac * param.size)
 92
 93                (
 94                    key,
 95                    ind_key,
 96                    prob_key,
 97                    normal_update_key,
 98                    reset_update_key,
 99                    ssne_prob_key,
100                ) = jax.random.split(key, 6)
101
102                ind = jax.random.choice(ind_key, param.shape[0], (num_mutations,))
103
104                prob = jax.random.uniform(prob_key, (num_mutations,))
105                super_mask = prob < super_mut_prob
106                reset_mask = jnp.logical_and(prob >= super_mut_prob, prob < reset_prob)
107
108                updates = jax.random.normal(
109                    normal_update_key, (num_mutations,)
110                ) * jnp.abs(param[ind])
111                updates = jnp.where(
112                    super_mask, updates * super_mut_strength, updates * mut_strength
113                )
114
115                reset_param = jax.random.normal(reset_update_key, (num_mutations,))
116                new_param = param.at[ind].set(
117                    jnp.where(reset_mask, reset_param, param[ind] + updates),
118                    unique_indices=True,
119                )
120
121                ssne_prob = jax.random.uniform(ssne_prob_key)
122                param = jnp.where(
123                    ssne_prob < ssne_probs[i] * vec_relative_prob, new_param, param
124                )
125
126                param = jnp.clip(param, -weight_max_magnitude, weight_max_magnitude)
127
128        else:
129            raise ValueError(f"Unsupported parameter shape: {param.shape}")
130
131        params.append(param)
132
133    return jtu.tree_unflatten(treedef, params)
134
135
class ERLMutation(PyTreeNode):
    """Population-level wrapper around ``erl_mutate``.

    The fields are the hyperparameters forwarded as keyword arguments to
    ``erl_mutate``. Calling an instance mutates every individual of a batched
    population by mapping ``erl_mutate`` over the leading axis with
    ``jax.vmap``.
    """

    weight_max_magnitude: float = 1e6
    mut_strength: float = 0.1
    num_mutation_frac: float = 0.1
    super_mut_strength: float = 10.0
    super_mut_prob: float = 0.05
    reset_prob: float = 0.05
    vec_relative_prob: float = 0.0

    def __post_init__(self):
        """Validate ``num_mutation_frac`` and build the vmapped mutation function."""
        assert 0 <= self.num_mutation_frac <= 1, (
            "num_mutation_frac should be in [0, 1]"
        )

        # Bind the hyperparameters once so only (x, key) remain to be mapped.
        hyperparams = {
            "weight_max_magnitude": self.weight_max_magnitude,
            "mut_strength": self.mut_strength,
            "num_mutation_frac": self.num_mutation_frac,
            "super_mut_strength": self.super_mut_strength,
            "super_mut_prob": self.super_mut_prob,
            "reset_prob": self.reset_prob,
            "vec_relative_prob": self.vec_relative_prob,
        }
        single_mutate = partial(erl_mutate, **hyperparams)
        self.mutate_fn = jax.vmap(single_mutate)

    def __call__(self, xs: chex.ArrayTree, key: chex.PRNGKey):
        """Mutate a batched population.

        Args:
            xs: pytree whose leaves carry the population on their leading axis;
                the population size is read from the first leaf.
            key: either a key with ``ndim <= 1``, which is split into one key
                per individual, or an already batched key that must have shape
                ``(pop_size, 2)``.

        Returns:
            The result of the vmapped ``erl_mutate`` applied to ``xs``.
        """
        first_leaf = jtu.tree_leaves(xs)[0]
        pop_size = first_leaf.shape[0]

        if key.ndim > 1:
            # Caller supplied per-individual keys; only check their shape.
            chex.assert_shape(
                key,
                (pop_size, 2),
                custom_message=f"Batched key shape {key.shape} must match pop_size: {pop_size}",
            )
        else:
            key = jax.random.split(key, pop_size)

        return self.mutate_fn(xs, key)