Source code for evorl.ec.operators.crossover.mlp_crossover

  1from functools import partial
  2
  3import chex
  4import jax
  5import jax.numpy as jnp
  6import jax.tree_util as jtu
  7
  8from evorl.types import PyTreeNode
  9from ..utils import is_layer_norm_layer
 10
 11
 12def mlp_crossover(
 13    x1: chex.ArrayTree,
 14    x2: chex.ArrayTree,
 15    key: chex.PRNGKey,
 16    *,
 17    num_crossover_frac: float = 0.1,
 18):
 19    chex.assert_trees_all_equal_shapes_and_dtypes(x1, x2)
 20
 21    leaves1, treedef = jtu.tree_flatten_with_path(x1)
 22    leaves2 = jtu.tree_leaves(x2)
 23
 24    params1 = []
 25    params2 = []
 26    for i, ((path, param1), param2) in enumerate(zip(leaves1, leaves2)):
 27        if is_layer_norm_layer(path):
 28            # skip layer norm layers
 29            params1.append(param1)
 30            params2.append(param2)
 31            continue
 32
 33        if param1.ndim <= 2:  # kernel
 34            # for 2d array, we exchange the rows
 35            # for 1d array, we exchange the elements
 36            key, ind_key, choice_key = jax.random.split(key, num=3)
 37
 38            # we use fixed number of crossover op: sampled without replacement
 39            # this is different from the original ERL: np.random.randint(num_variables * 2)
 40            num_crossover = round(param1.shape[0] * num_crossover_frac)
 41
 42            ind = jax.random.choice(
 43                ind_key, param1.shape[0], (num_crossover,), replace=False
 44            )
 45
 46            mask = jax.random.uniform(choice_key, (num_crossover,)) < 0.5
 47            if param1.ndim > 1:
 48                mask = mask[..., None]
 49
 50            zero_update = jnp.zeros((num_crossover, *param1.shape[1:]))
 51
 52            param1 = param1.at[ind].set(
 53                jnp.where(mask, zero_update, param2[ind]),
 54                unique_indices=True,
 55            )
 56
 57            param2 = param2.at[ind].set(
 58                jnp.where(jnp.logical_not(mask), zero_update, param1[ind]),
 59                unique_indices=True,
 60            )
 61
 62        else:
 63            raise ValueError(f"Unsupported parameter shape: {param1.shape}")
 64
 65        params1.append(param1)
 66        params2.append(param2)
 67
 68    return jtu.tree_unflatten(treedef, params1), jtu.tree_unflatten(treedef, params2)
 69
 70
class MLPCrossover(PyTreeNode):
    """Population-level wrapper that applies ``mlp_crossover`` to parent pairs.

    The population is split into pairs ``(xs[0], xs[1]), (xs[2], xs[3]), ...``
    along the leading axis of every leaf, and ``mlp_crossover`` is vmapped
    over those pairs.
    """

    # Fraction of leading-axis entries crossed over per leaf; must be in [0, 1].
    num_crossover_frac: float = 0.1

    def __post_init__(self):
        assert 0 <= self.num_crossover_frac <= 1, (
            "num_crossover_frac should be in [0, 1]"
        )

        # NOTE(review): this assumes PyTreeNode instances allow attribute
        # assignment inside __post_init__ -- confirm against evorl.types.
        self.crossover_fn = jax.vmap(
            partial(mlp_crossover, num_crossover_frac=self.num_crossover_frac),
        )

    def __call__(self, xs: chex.ArrayTree, key: chex.PRNGKey):
        """Cross over adjacent individuals of a batched population.

        Args:
            xs: Pytree whose leaves share a leading population axis of even
                size ``pop_size``.
            key: Either a single PRNG key (``key.ndim <= 1``), which is split
                into one key per parent pair, or a batch of keys with shape
                ``(pop_size // 2, 2)``, one per parent pair.

        Returns:
            A pytree with the same leading size ``pop_size``: the first half
            holds the offspring derived from the even-indexed parents, the
            second half those derived from the odd-indexed parents.
        """
        pop_size = jtu.tree_leaves(xs)[0].shape[0]
        assert pop_size % 2 == 0, "pop_size must be even"
        num_pairs = pop_size // 2

        parents1 = jtu.tree_map(lambda x: x[0::2], xs)
        parents2 = jtu.tree_map(lambda x: x[1::2], xs)

        if key.ndim <= 1:
            key = jax.random.split(key, num_pairs)
        else:
            # The vmapped crossover runs once per parent pair, so a batched
            # key must provide exactly one key per pair, not one per
            # individual.
            chex.assert_shape(
                key,
                (num_pairs, 2),
                custom_message=(
                    f"Batched key shape {key.shape} must match "
                    f"pop_size // 2: {num_pairs}"
                ),
            )

        offsprings1, offsprings2 = self.crossover_fn(parents1, parents2, key)
        return jtu.tree_map(
            lambda x1, x2: jnp.concatenate([x1, x2], axis=0), offsprings1, offsprings2
        )