Source code for evorl.ec.optimizers.vanilla_ga

 1import chex
 2import jax
 3import jax.numpy as jnp
 4import jax.tree_util as jtu
 5
 6from evorl.types import PyTreeData, PyTreeDict
 7from evorl.ec.operators import MLPMutation, MLPCrossover, TournamentSelection
 8
 9from .ec_optimizer import EvoOptimizer
10
11
class VanillaGAState(PyTreeData):
    """State of the VanillaGA.

    Attributes:
        pop: Current population. A pytree whose leaves are expected to share a
            leading population axis (``VanillaGA.tell`` indexes and concatenates
            leaves along axis 0).
        key: PRNG key; ``VanillaGA.tell`` splits it and stores a new key in the
            returned state.
    """

    pop: chex.ArrayTree
    key: chex.PRNGKey
class VanillaGA(EvoOptimizer):
    """Vanilla Genetic Algorithm.

    The Genetic Algorithm used in the original ERL.
    Paper: [Evolution-Guided Policy Gradient in Reinforcement Learning](https://arxiv.org/abs/1805.07917)

    Each ``tell`` keeps the ``num_elites`` highest-fitness individuals unchanged
    and fills the remaining ``pop_size - num_elites`` slots with offspring
    produced by tournament selection, optional crossover, and mutation.
    """

    # Total number of individuals produced per generation (elites + offspring).
    pop_size: int
    # Number of highest-fitness individuals copied unchanged into the next generation.
    num_elites: int

    # selection
    tournament_size: int = 2  # forwarded to TournamentSelection

    # mutation (all forwarded to MLPMutation)
    weight_max_magnitude: float = 1e6
    mut_strength: float = 0.1
    vector_num_mutation_frac: float = 0.0
    matrix_num_mutation_frac: float = 0.01

    # crossover
    enable_crossover: bool = True
    num_crossover_frac: float = 0.1  # forwarded to MLPCrossover; unused when crossover is disabled

    def __post_init__(self):
        """Validate the configuration and build the GA operators.

        Raises:
            ValueError: If ``num_elites`` is outside ``[0, pop_size]``, or if
                crossover is enabled and ``pop_size - num_elites`` is odd.
        """
        # Use explicit exceptions rather than `assert`: asserts are stripped
        # under `python -O`, which would let a bad config fail much later with
        # an obscure shape error.
        if not 0 <= self.num_elites <= self.pop_size:
            raise ValueError(
                f"num_elites must be in [0, pop_size], got "
                f"num_elites={self.num_elites}, pop_size={self.pop_size}"
            )
        if self.enable_crossover and (self.pop_size - self.num_elites) % 2 != 0:
            # NOTE(review): presumably MLPCrossover combines parents pairwise,
            # hence the even-count requirement — confirm against MLPCrossover.
            raise ValueError(
                "(pop_size - num_elites) must be even when enable crossover"
            )

        self.select_parents = TournamentSelection(tournament_size=self.tournament_size)
        self.mutate = MLPMutation(
            weight_max_magnitude=self.weight_max_magnitude,
            mut_strength=self.mut_strength,
            vector_num_mutation_frac=self.vector_num_mutation_frac,
            matrix_num_mutation_frac=self.matrix_num_mutation_frac,
        )
        # Only built when needed; `tell` never touches it otherwise.
        if self.enable_crossover:
            self.crossover = MLPCrossover(num_crossover_frac=self.num_crossover_frac)

    def init(self, pop: chex.ArrayTree, key: chex.PRNGKey) -> VanillaGAState:
        """Wrap the initial population and PRNG key into a ``VanillaGAState``."""
        return VanillaGAState(pop=pop, key=key)

    def ask(self, state: VanillaGAState) -> tuple[chex.ArrayTree, VanillaGAState]:
        """Return the current population for evaluation and the unchanged state."""
        return state.pop, state

    def tell(
        self, state: VanillaGAState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, VanillaGAState]:
        """Produce the next generation from the fitnesses of the current one.

        Args:
            state: Current optimizer state.
            fitnesses: Fitness of each individual in ``state.pop``, aligned with
                the leading axis of the population leaves (assumed shape
                ``(pop_size,)`` — TODO confirm). Elites are the individuals with
                the largest values.

        Returns:
            An empty ``PyTreeDict`` of metrics, and the updated state whose
            population is the elites followed by the offspring and whose key
            is a newly split PRNG key.
        """
        # Note: We simplify the update in ERL
        key, select_key, mutate_key, crossover_key = jax.random.split(state.key, 4)

        # Elites: the `num_elites` highest-fitness individuals, copied unchanged.
        elite_indices = jnp.argsort(fitnesses, descending=True)[: self.num_elites]
        elites = jtu.tree_map(lambda x: x[elite_indices], state.pop)

        # Parents fill the remaining pop_size - num_elites slots.
        parents_indices = self.select_parents(
            fitnesses, self.pop_size - self.num_elites, select_key
        )
        parents = jtu.tree_map(lambda x: x[parents_indices], state.pop)

        if self.enable_crossover:
            offsprings = self.crossover(parents, crossover_key)
            offsprings = self.mutate(offsprings, mutate_key)
        else:
            offsprings = self.mutate(parents, mutate_key)

        # Stack elites and offspring along the population (leading) axis.
        new_pop = jtu.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0), elites, offsprings
        )

        return PyTreeDict(), state.replace(pop=new_pop, key=key)