Source code for evorl.ec.optimizers.erl_ga

  1import math
  2
  3import chex
  4import jax
  5import jax.numpy as jnp
  6import jax.tree_util as jtu
  7
  8from evorl.types import PyTreeData, PyTreeDict
  9from evorl.ec.operators import ERLMutation, MLPCrossover, TournamentSelection
 10from evorl.utils.jax_utils import tree_get
 11
 12from .ec_optimizer import EvoOptimizer
 13
 14
class ERLGAState(PyTreeData):
    """Optimizer state carried between ``ask``/``tell`` calls of ``ERLGA``."""

    # Current population pytree. ERLGA.tell indexes it with tree_get and
    # concatenates leaves along axis 0, so the leading axis presumably
    # indexes individuals — confirm against callers.
    pop: chex.ArrayTree
    # PRNG key; ERLGA.tell splits it and stores the first split back here.
    key: chex.PRNGKey
class ERLGA(EvoOptimizer):
    """Simplified ERL-style genetic algorithm.

    Every generation keeps the ``num_elites`` highest-fitness individuals
    unchanged. It fills the remaining ``pop_size - num_elites`` slots with
    children produced by tournament selection, optional crossover
    (``enable_crossover``) and mutation. Larger fitness values are treated
    as better.
    """

    pop_size: int
    num_elites: int

    # selection
    tournament_size: int = 3

    # mutation
    weight_max_magnitude: float = 1e6
    mut_strength: float = 0.1
    num_mutation_frac: float = 0.1
    super_mut_strength: float = 10.0
    super_mut_prob: float = 0.05
    reset_prob: float = 0.1
    vec_relative_prob: float = 0.0

    # crossover
    enable_crossover: bool = True
    num_crossover_frac: float = 0.1

    def __post_init__(self):
        """Validate the configuration and instantiate the GA operators."""
        assert self.pop_size >= self.num_elites, "num_elites must be <= pop_size"
        assert not self.enable_crossover or (
            (self.pop_size - self.num_elites) % 2 == 0
        ), "(pop_size - num_elites) must be even when enable crossover"

        mutation_kwargs = dict(
            weight_max_magnitude=self.weight_max_magnitude,
            mut_strength=self.mut_strength,
            num_mutation_frac=self.num_mutation_frac,
            super_mut_strength=self.super_mut_strength,
            super_mut_prob=self.super_mut_prob,
            reset_prob=self.reset_prob,
            vec_relative_prob=self.vec_relative_prob,
        )
        self.select_parents = TournamentSelection(tournament_size=self.tournament_size)
        self.mutate = ERLMutation(**mutation_kwargs)
        # The crossover operator only exists when crossover is enabled.
        if self.enable_crossover:
            self.crossover = MLPCrossover(num_crossover_frac=self.num_crossover_frac)

    def init(self, pop, key) -> ERLGAState:
        """Bundle the initial population and PRNG key into a fresh state."""
        return ERLGAState(key=key, pop=pop)

    def ask(self, state: ERLGAState) -> tuple[chex.ArrayTree, ERLGAState]:
        """Return the current population for evaluation; the state is unchanged."""
        candidates = state.pop
        return candidates, state

    def tell(
        self, state: ERLGAState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, ERLGAState]:
        """Produce the next generation from the evaluated population.

        Args:
            state: current optimizer state.
            fitnesses: one fitness per individual of ``state.pop``; larger is
                better.

        Returns:
            An empty ``PyTreeDict`` of metrics and the state holding the new
            population (elites first, then children) and the advanced key.
        """
        # Note: this is a simplified version of the ERL update.
        next_key, k_select, k_mutate, k_cross = jax.random.split(state.key, 4)

        ranking = jnp.argsort(fitnesses, descending=True)
        elites = tree_get(state.pop, ranking[: self.num_elites])

        num_children = self.pop_size - self.num_elites
        if not self.enable_crossover:
            chosen = self.select_parents(fitnesses, num_children, k_select)
            children = self.mutate(tree_get(state.pop, chosen), k_mutate)
        else:
            # Crossover works on pairs: draw an even number of parents and
            # drop the surplus child if the target count is odd.
            paired_count = 2 * math.ceil(num_children / 2)
            chosen = self.select_parents(fitnesses, paired_count, k_select)
            crossed = self.crossover(tree_get(state.pop, chosen), k_cross)
            if paired_count != num_children:
                crossed = tree_get(crossed, slice(num_children))
            children = self.mutate(crossed, k_mutate)

        def _stack(first, second):
            return jnp.concatenate([first, second], axis=0)

        next_pop = jtu.tree_map(_stack, elites, children)
        return PyTreeDict(), state.replace(pop=next_pop, key=next_key)
class ERLGAModState(ERLGAState):
    """``ERLGAState`` extended with an optional externally supplied population."""

    # Individuals injected from outside the GA (presumably e.g. RL agents —
    # confirm with callers). ERLGAMod.tell_external merges them into the
    # population and then resets this field to None.
    external_pop: None | chex.ArrayTree = None
class ERLGAMod(ERLGA):
    """``ERLGA`` variant that reserves population slots for external individuals.

    ``tell_external`` excludes the ``external_size`` lowest-ranked individuals
    from parent selection and fills their slots with ``state.external_pop``
    (e.g. individuals coming from an RL learner). The inherited ``tell`` is
    unchanged.
    """

    external_size: int

    def __post_init__(self):
        """Validate the slot budget, then run the ``ERLGA`` checks and setup."""
        assert self.pop_size >= (self.num_elites + self.external_size), (
            "num_elites+external_size must be <= pop_size"
        )
        # NOTE(review): the parity check inherited from ERLGA.__post_init__
        # tests (pop_size - num_elites), not the offspring count used in
        # tell_external; an odd offspring count is handled there by trimming.
        super().__post_init__()

    def init(self, pop, key) -> ERLGAModState:
        """Create the initial state with no pending external population."""
        return ERLGAModState(pop=pop, key=key)

    def tell_external(
        self, state: ERLGAModState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, ERLGAModState]:
        """Build the next generation, injecting ``state.external_pop``.

        The new population is laid out along the leading axis as
        ``[elites, offspring, external_pop]``:

        * elites: the ``num_elites`` highest-fitness individuals, copied as-is.
        * offspring: ``pop_size - num_elites - external_size`` children. They
          are bred by tournament selection restricted to all but the
          ``external_size`` lowest-ranked individuals, then optional
          crossover, then mutation.
        * external_pop: the externally supplied individuals, which take the
          place of the worst ones.

        Args:
            state: current state. ``external_pop`` must be set when
                ``external_size > 0``; its size is assumed to be
                ``external_size`` (not checked here).
            fitnesses: one fitness per individual of ``state.pop``; larger is
                better.

        Returns:
            An empty ``PyTreeDict`` of metrics and the updated state with the
            new population, an advanced PRNG key, and ``external_pop`` reset
            to None.

        Raises:
            ValueError: if ``external_size > 0`` but ``state.external_pop`` is
                None.
        """
        # Note: this is a simplified version of the ERL update.
        key, select_key, mutate_key, crossover_key = jax.random.split(state.key, 4)

        sorted_indices = jnp.argsort(fitnesses, descending=True)
        elites = tree_get(state.pop, sorted_indices[: self.num_elites])

        # The worst `external_size` individuals are replaced by the external
        # population, so only the remaining ones may become parents. Use an
        # explicit stop: `sorted_indices[:-k]` is EMPTY when k == 0.
        num_candidates = fitnesses.shape[0] - self.external_size
        candidate_indices = sorted_indices[:num_candidates]

        num_offsprings = self.pop_size - self.num_elites - self.external_size
        if self.enable_crossover:
            # Crossover consumes parents in pairs: draw an even count and trim
            # the surplus child afterwards if num_offsprings is odd.
            num_parents = math.ceil(num_offsprings / 2) * 2
        else:
            num_parents = num_offsprings

        # select_parents' output indexes `candidate_indices`, i.e. it gives
        # positions within the candidate subset; map back to population rows.
        parents_indices = candidate_indices[
            self.select_parents(
                fitnesses[candidate_indices],
                num_parents,
                select_key,
            )
        ]
        parents = tree_get(state.pop, parents_indices)

        if self.enable_crossover:
            offsprings = self.crossover(parents, crossover_key)
            if num_offsprings % 2 != 0:
                offsprings = tree_get(offsprings, slice(num_offsprings))
        else:
            offsprings = parents
        offsprings = self.mutate(offsprings, mutate_key)

        external = state.external_pop
        if external is None:
            if self.external_size > 0:
                raise ValueError(
                    "state.external_pop must be set before calling tell_external "
                    "when external_size > 0"
                )
            # Nothing reserved for external individuals: elites + offspring
            # already make up the whole population.
            new_pop = jtu.tree_map(
                lambda x, y: jnp.concatenate([x, y], axis=0),
                elites,
                offsprings,
            )
        else:
            new_pop = jtu.tree_map(
                lambda x, y, z: jnp.concatenate([x, y, z], axis=0),
                elites,
                offsprings,
                external,
            )

        return PyTreeDict(), state.replace(pop=new_pop, key=key, external_pop=None)