Source code for evorl.ec.optimizers.vanilla_ga
1import chex
2import jax
3import jax.numpy as jnp
4import jax.tree_util as jtu
5
6from evorl.types import PyTreeData, PyTreeDict
7from evorl.ec.operators import MLPMutation, MLPCrossover, TournamentSelection
8
9from .ec_optimizer import EvoOptimizer
10
11
[docs]
class VanillaGAState(PyTreeData):
    """Container for everything VanillaGA carries between generations."""

    # Pytree holding the current population; VanillaGA.tell indexes each
    # leaf along its leading axis to pick individuals.
    pop: chex.ArrayTree
    # PRNG key; VanillaGA.tell splits it and stores the fresh key back here.
    key: chex.PRNGKey
17
18
[docs]
class VanillaGA(EvoOptimizer):
    """Vanilla Genetic Algorithm.

    This is the genetic algorithm used by the original ERL.
    Paper: [Evolution-Guided Policy Gradient in Reinforcement Learning](https://arxiv.org/abs/1805.07917)

    Each call to ``tell`` keeps the ``num_elites`` individuals with the
    highest fitness unchanged and fills the remaining
    ``pop_size - num_elites`` slots with offspring produced by tournament
    selection, optional crossover, and mutation.
    """

    # Total number of individuals, and how many top-fitness individuals are
    # copied unchanged into the next generation.
    pop_size: int
    num_elites: int

    # Parent selection: forwarded to TournamentSelection.
    tournament_size: int = 2

    # Mutation: all four values are forwarded to MLPMutation.
    weight_max_magnitude: float = 1e6
    mut_strength: float = 0.1
    vector_num_mutation_frac: float = 0.0
    matrix_num_mutation_frac: float = 0.01

    # Crossover: the fraction is forwarded to MLPCrossover, which is only
    # built (and only used) when enable_crossover is True.
    enable_crossover: bool = True
    num_crossover_frac: float = 0.1

    def __post_init__(self):
        """Check the configuration and build the genetic operators."""
        if self.enable_crossover:
            # The parity requirement is only enforced when crossover is on.
            num_offsprings = self.pop_size - self.num_elites
            assert num_offsprings % 2 == 0, (
                "(pop_size - num_elites) must be even when enable crossover"
            )

        self.select_parents = TournamentSelection(tournament_size=self.tournament_size)

        mutation_config = dict(
            weight_max_magnitude=self.weight_max_magnitude,
            mut_strength=self.mut_strength,
            vector_num_mutation_frac=self.vector_num_mutation_frac,
            matrix_num_mutation_frac=self.matrix_num_mutation_frac,
        )
        self.mutate = MLPMutation(**mutation_config)

        if self.enable_crossover:
            self.crossover = MLPCrossover(num_crossover_frac=self.num_crossover_frac)

    def init(self, pop: chex.ArrayTree, key: chex.PRNGKey) -> VanillaGAState:
        """Wrap an initial population and PRNG key into a fresh state.

        Args:
            pop: Initial population pytree.
            key: PRNG key stored in the state for later ``tell`` calls.

        Returns:
            A ``VanillaGAState`` holding ``pop`` and ``key`` as given.
        """
        return VanillaGAState(pop=pop, key=key)

    def ask(self, state: VanillaGAState) -> tuple[chex.ArrayTree, VanillaGAState]:
        """Return the population to evaluate together with the state.

        Args:
            state: Current optimizer state.

        Returns:
            A tuple ``(state.pop, state)``; the state is passed through
            unmodified.
        """
        return state.pop, state

    def tell(
        self, state: VanillaGAState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, VanillaGAState]:
        """Produce the next generation from the evaluated population.

        Note: this is a simplified version of the update used in ERL.

        Args:
            state: Current optimizer state.
            fitnesses: Fitness of each individual in ``state.pop``; larger
                values rank higher when choosing elites.

        Returns:
            A tuple of an empty ``PyTreeDict`` and the state with its
            population replaced by ``[elites, offsprings]`` (concatenated
            along axis 0) and its PRNG key advanced.
        """
        next_key, k_select, k_mutate, k_crossover = jax.random.split(state.key, 4)

        def gather(indices):
            # Pick the given individuals out of every leaf of the population.
            return jtu.tree_map(lambda leaf: leaf[indices], state.pop)

        # Elites: the num_elites best individuals, best first.
        ranking = jnp.argsort(fitnesses, descending=True)
        elites = gather(ranking[: self.num_elites])

        # Parents for the remaining slots come from tournament selection.
        num_offsprings = self.pop_size - self.num_elites
        parents = gather(self.select_parents(fitnesses, num_offsprings, k_select))

        # Crossover is optional; mutation is always applied afterwards.
        candidates = (
            self.crossover(parents, k_crossover) if self.enable_crossover else parents
        )
        offsprings = self.mutate(candidates, k_mutate)

        new_pop = jtu.tree_map(
            lambda kept, bred: jnp.concatenate([kept, bred], axis=0),
            elites,
            offsprings,
        )

        return PyTreeDict(), state.replace(pop=new_pop, key=next_key)