1import math
2
3import chex
4import jax
5import jax.numpy as jnp
6import jax.tree_util as jtu
7
8from evorl.types import PyTreeData, PyTreeDict
9from evorl.ec.operators import ERLMutation, MLPCrossover, TournamentSelection
10from evorl.utils.jax_utils import tree_get
11
12from .ec_optimizer import EvoOptimizer
13
14
[docs]
class ERLGAState(PyTreeData):
    """State threaded through ``ERLGA``: the current population and a PRNG key."""

    # Population pytree. ``ERLGA.tell`` gathers individuals from it and
    # concatenates the resulting leaves along axis 0.
    pop: chex.ArrayTree
    # PRNG key; ``ERLGA.tell`` splits it on every update and stores a fresh one.
    key: chex.PRNGKey
18
19
[docs]
class ERLGA(EvoOptimizer):
    """Genetic algorithm with elitism, tournament selection, optional crossover
    and mutation.

    Each ``tell`` keeps the ``num_elites`` best individuals unchanged and fills
    the remaining ``pop_size - num_elites`` slots with offspring produced from
    tournament-selected parents.

    Note: per the original author, the update is a simplification of the one
    used in ERL.
    """

    pop_size: int
    num_elites: int

    # --- parent selection: forwarded to TournamentSelection ---
    tournament_size: int = 3

    # --- mutation: every field below is forwarded unchanged to ERLMutation ---
    weight_max_magnitude: float = 1e6
    mut_strength: float = 0.1
    num_mutation_frac: float = 0.1
    super_mut_strength: float = 10.0
    super_mut_prob: float = 0.05
    reset_prob: float = 0.1
    vec_relative_prob: float = 0.0

    # --- crossover: num_crossover_frac is forwarded to MLPCrossover ---
    enable_crossover: bool = True
    num_crossover_frac: float = 0.1

    def __post_init__(self):
        """Validate the size configuration and instantiate the GA operators."""
        assert self.pop_size >= self.num_elites, "num_elites must be <= pop_size"
        # Crossover is applied to an even number of parents, so the number of
        # non-elite slots has to be even whenever crossover is switched on.
        assert (
            not self.enable_crossover
            or (self.pop_size - self.num_elites) % 2 == 0
        ), "(pop_size - num_elites) must be even when enable crossover"

        self.select_parents = TournamentSelection(tournament_size=self.tournament_size)

        mutation_fields = (
            "weight_max_magnitude",
            "mut_strength",
            "num_mutation_frac",
            "super_mut_strength",
            "super_mut_prob",
            "reset_prob",
            "vec_relative_prob",
        )
        self.mutate = ERLMutation(
            **{name: getattr(self, name) for name in mutation_fields}
        )

        # ``self.crossover`` only exists when crossover is enabled.
        if self.enable_crossover:
            self.crossover = MLPCrossover(num_crossover_frac=self.num_crossover_frac)

    def init(self, pop, key) -> ERLGAState:
        """Wrap an initial population and PRNG key into an ``ERLGAState``."""
        return ERLGAState(pop=pop, key=key)

    def ask(self, state: ERLGAState) -> tuple[chex.ArrayTree, ERLGAState]:
        """Return the current population together with the unchanged state."""
        return state.pop, state

    def tell(
        self, state: ERLGAState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, ERLGAState]:
        """Build the next generation from the fitnesses of the current one.

        Args:
            state: Current optimizer state.
            fitnesses: Fitness of each individual in ``state.pop``; larger
                values rank first (descending argsort).

        Returns:
            An empty ``PyTreeDict`` and the state holding the new population
            (elites first, then offspring) and a fresh PRNG key.
        """
        # Note: We simplify the update in ERL
        next_key, select_key, mutate_key, crossover_key = jax.random.split(
            state.key, 4
        )

        ranking = jnp.argsort(fitnesses, descending=True)
        elites = tree_get(state.pop, ranking[: self.num_elites])

        num_offsprings = self.pop_size - self.num_elites
        if self.enable_crossover:
            # Round up to an even count so that crossover always gets pairs.
            num_parents = math.ceil(num_offsprings / 2) * 2
        else:
            num_parents = num_offsprings

        chosen = self.select_parents(fitnesses, num_parents, select_key)
        parents = tree_get(state.pop, chosen)

        if self.enable_crossover:
            children = self.crossover(parents, crossover_key)
            if num_offsprings % 2 != 0:
                # Trim the surplus child introduced by the rounding above.
                children = tree_get(children, slice(num_offsprings))
        else:
            children = parents
        children = self.mutate(children, mutate_key)

        def _stack(elite_leaf, child_leaf):
            return jnp.concatenate([elite_leaf, child_leaf], axis=0)

        next_pop = jtu.tree_map(_stack, elites, children)
        return PyTreeDict(), state.replace(pop=next_pop, key=next_key)
96
97
[docs]
class ERLGAModState(ERLGAState):
    """``ERLGAState`` extended with a slot for externally supplied individuals."""

    # Individuals produced outside the GA (e.g. from RL, per the comment in
    # ``ERLGAMod.tell_external``). ``tell_external`` appends them to the next
    # population and then resets this field to None.
    external_pop: None | chex.ArrayTree = None
100
101
[docs]
class ERLGAMod(ERLGA):
    """``ERLGA`` variant that reserves ``external_size`` slots of every new
    population for individuals supplied from outside (e.g. from RL).

    The next population is laid out as
    ``[elites, offsprings, external_pop]`` along axis 0.

    NOTE(review): the inherited ``__post_init__`` still requires
    ``pop_size - num_elites`` to be even when crossover is enabled, although
    the number of offspring here is ``pop_size - num_elites - external_size``.
    Odd offspring counts are handled in ``tell_external`` by trimming.
    """

    external_size: int

    def __post_init__(self):
        """Check that elites and external slots fit, then run the base setup."""
        assert self.pop_size >= (self.num_elites + self.external_size), (
            "num_elites+external_size must be <= pop_size"
        )
        super().__post_init__()

    def init(self, pop, key) -> ERLGAModState:
        """Wrap an initial population and PRNG key into an ``ERLGAModState``."""
        return ERLGAModState(pop=pop, key=key)

    def tell_external(
        self, state: ERLGAModState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, ERLGAModState]:
        """Build the next generation, replacing the worst individuals with
        ``state.external_pop``.

        Args:
            state: Current state. ``state.external_pop`` must hold the external
                individuals unless ``external_size`` is 0.
            fitnesses: Fitness of each individual in ``state.pop``; larger
                values rank first (descending argsort).

        Returns:
            An empty ``PyTreeDict`` and the state holding the new population,
            a fresh PRNG key and ``external_pop`` reset to None.

        Raises:
            ValueError: If ``external_size`` is non-zero but
                ``state.external_pop`` is None.
        """
        external = state.external_pop
        if external is None and self.external_size != 0:
            # Fail early with a clear message instead of an opaque pytree
            # structure error from tree_map below.
            raise ValueError(
                "state.external_pop must be set before calling tell_external "
                "when external_size != 0"
            )

        # Note: We simplify the update in ERL
        key, select_key, mutate_key, crossover_key = jax.random.split(state.key, 4)

        sorted_indices = jnp.argsort(fitnesses, descending=True)
        elites = tree_get(state.pop, sorted_indices[: self.num_elites])

        # The worst `external_size` individuals are excluded from parent
        # selection; their slots are taken by the external individuals.
        # An explicit length is used because `[: -external_size]` would be
        # `[:0]` (empty) when external_size == 0.
        num_selected = max(sorted_indices.shape[0] - self.external_size, 0)
        selected_indices = sorted_indices[:num_selected]

        num_offsprings = self.pop_size - self.num_elites - self.external_size
        if self.enable_crossover:
            # Round up to an even count so that crossover always gets pairs.
            num_parents = math.ceil(num_offsprings / 2) * 2
        else:
            num_parents = num_offsprings

        # select_parents returns positions within the restricted fitness
        # array; map them back to indices into the full population.
        parents_indices = selected_indices[
            self.select_parents(
                fitnesses[selected_indices],
                num_parents,
                select_key,
            )
        ]
        parents = tree_get(state.pop, parents_indices)

        if self.enable_crossover:
            offsprings = self.crossover(parents, crossover_key)
            if num_offsprings % 2 != 0:
                # Trim the surplus child introduced by the rounding above.
                offsprings = tree_get(offsprings, slice(num_offsprings))
        else:
            offsprings = parents
        offsprings = self.mutate(offsprings, mutate_key)

        parts = [elites, offsprings]
        if external is not None:
            parts.append(external)

        new_pop = jtu.tree_map(
            lambda *leaves: jnp.concatenate(leaves, axis=0),
            *parts,
        )

        return PyTreeDict(), state.replace(pop=new_pop, key=key, external_pop=None)