Source code for evorl.ec.optimizers.evox_wrapper
1import chex
2from typing import Any, TYPE_CHECKING
3
4if TYPE_CHECKING:
5 from evox import (
6 State as EvoXState,
7 Algorithm,
8 )
9else:
10 EvoXState = Any
11 Algorithm = Any
12
13from evorl.types import PyTreeData, pytree_field, PyTreeDict
14from evorl.utils.ec_utils import ParamVectorSpec
15
16from .ec_optimizer import EvoOptimizer
17
18
[docs]
class EvoXAlgoState(PyTreeData):
    """State of an EvoX algorithm wrapped as an EvoRL optimizer.

    Bundles the wrapped EvoX algorithm's own state with a flag the adapter
    consults to decide whether to dispatch to the algorithm's
    ``init_ask``/``init_tell`` methods; the adapter clears the flag in ``tell``.
    """

    # State object of the wrapped EvoX algorithm; threaded through its
    # init/ask/tell calls.
    algo_state: EvoXState
    # Declared with ``pytree_field(static=True)``; presumably this keeps it out
    # of the traced leaves so Python-level branching on it is allowed under
    # jit (TODO confirm against pytree_field's implementation).
    init_step: bool = pytree_field(static=True)
22
23
[docs]
class EvoXAlgorithmAdapter(EvoOptimizer):
    """Adapter class to convert EvoX algorithms to EvoRL optimizers.

    Populations produced by the wrapped algorithm are converted into
    parameter pytrees with ``param_vec_spec.to_tree``. Fitnesses are negated
    before being handed back to the algorithm, because EvoX algorithms
    minimize their objective.
    """

    # The wrapped EvoX algorithm.
    algorithm: Algorithm
    # Converts the algorithm's flat population into parameter pytrees.
    param_vec_spec: ParamVectorSpec

    def init(self, key: chex.PRNGKey) -> EvoXAlgoState:
        """Initialize the wrapped EvoX algorithm.

        Args:
            key: PRNG key forwarded to ``algorithm.init``.

        Returns:
            An ``EvoXAlgoState`` wrapping the EvoX state. Its ``init_step`` is
            True when the algorithm overrides ``init_ask`` and/or
            ``init_tell``, so the first ask/tell round can use them.
        """
        algo_state = self.algorithm.init(key)

        from evox import has_init_tell, has_init_ask

        # init_ask and init_tell are detected independently, so enter the
        # initial round if either one is provided; ask()/tell() each fall
        # back to the regular method when their init_* counterpart is absent.
        # This avoids asserting (stripped under -O) on init_tell-only
        # algorithms and silently skipping init_ask on init_ask-only ones.
        # NOTE(review): this is intended to mirror EvoX's StdWorkflow, which
        # checks the two predicates separately — confirm for the EvoX version
        # in use.
        init_step = bool(
            has_init_ask(self.algorithm) or has_init_tell(self.algorithm)
        )

        return EvoXAlgoState(algo_state=algo_state, init_step=init_step)

    def ask(self, state: EvoXAlgoState) -> tuple[chex.ArrayTree, EvoXAlgoState]:
        """Sample a population from the wrapped algorithm.

        During the initial round, ``algorithm.init_ask`` is used if the
        algorithm provides it; otherwise ``algorithm.ask`` is used.

        Args:
            state: Current adapter state.

        Returns:
            A tuple ``(pop, new_state)``, where ``pop`` is the population
            converted by ``param_vec_spec.to_tree``.
        """
        from evox import has_init_ask

        if state.init_step and has_init_ask(self.algorithm):
            ask = self.algorithm.init_ask
        else:
            ask = self.algorithm.ask

        flat_pop, algo_state = ask(state.algo_state)

        pop = self.param_vec_spec.to_tree(flat_pop)

        # init_step is intentionally left unchanged here: tell() clears it, so
        # a matching init_tell (if any) still runs for this round.
        return pop, state.replace(algo_state=algo_state)

    def tell(
        self, state: EvoXAlgoState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, EvoXAlgoState]:
        """Report fitnesses of the last population to the wrapped algorithm.

        During the initial round, ``algorithm.init_tell`` is used if the
        algorithm provides it; otherwise ``algorithm.tell`` is used.

        Args:
            state: Adapter state returned by the preceding ``ask``.
            fitnesses: Fitness values of the population. They are negated
                before being passed to the algorithm.

        Returns:
            A tuple of an empty ``PyTreeDict`` and the new state, whose
            ``init_step`` is always False.
        """
        from evox import has_init_tell

        if state.init_step and has_init_tell(self.algorithm):
            tell = self.algorithm.init_tell
        else:
            tell = self.algorithm.tell

        # Note: Evox's Algorithms minimize the fitness
        algo_state = tell(state.algo_state, -fitnesses)

        # The initial round (if any) ends after the first tell.
        return PyTreeDict(), state.replace(algo_state=algo_state, init_step=False)