Source code for evorl.ec.optimizers.evox_wrapper

 1import chex
 2from typing import Any, TYPE_CHECKING
 3
 4if TYPE_CHECKING:
 5    from evox import (
 6        State as EvoXState,
 7        Algorithm,
 8    )
 9else:
10    EvoXState = Any
11    Algorithm = Any
12
13from evorl.types import PyTreeData, pytree_field, PyTreeDict
14from evorl.utils.ec_utils import ParamVectorSpec
15
16from .ec_optimizer import EvoOptimizer
17
18
class EvoXAlgoState(PyTreeData):
    """Optimizer state carried between ask/tell calls of the EvoX adapter.

    Holds the wrapped EvoX algorithm's own state plus a flag recording
    whether the next ask/tell round should use the algorithm's dedicated
    initialization methods (``init_ask`` / ``init_tell``).
    """

    # State object produced by the EvoX algorithm's init and threaded through
    # its ask/tell methods.
    algo_state: EvoXState
    # True only before the first tell() for algorithms that provide
    # init_ask/init_tell. Declared static (not a pytree leaf), presumably so
    # it can drive Python-level `if` branches under jax.jit.
    init_step: bool = pytree_field(static=True)
class EvoXAlgorithmAdapter(EvoOptimizer):
    """Adapter class to convert EvoX algorithms to EvoRL optimizers.

    Wraps an EvoX ``Algorithm`` behind the EvoRL ask/tell interface:

    * ``ask`` converts the flat population produced by EvoX into parameter
      pytrees via ``param_vec_spec.to_tree``.
    * ``tell`` negates the fitnesses before passing them to EvoX, because
      EvoX algorithms minimize; callers therefore supply fitnesses where
      larger is better.
    * Algorithms that provide ``init_ask``/``init_tell`` use them for the
      first ask/tell round only, tracked by ``EvoXAlgoState.init_step``.

    ``evox`` is imported lazily inside the methods (the module-level import
    is guarded by ``TYPE_CHECKING``), presumably so EvoX stays an optional
    dependency.
    """

    algorithm: Algorithm  # the wrapped EvoX algorithm
    # Converts EvoX's flat population vectors into parameter pytrees.
    param_vec_spec: ParamVectorSpec

    def init(self, key: chex.PRNGKey) -> EvoXAlgoState:
        """Initialize the wrapped EvoX algorithm.

        Args:
            key: PRNG key forwarded to ``algorithm.init``.

        Returns:
            The adapter state. ``init_step`` is True when the algorithm
            provides ``init_tell`` (and hence ``init_ask``), so the first
            ask/tell round uses the dedicated initialization methods.

        Raises:
            ValueError: If the algorithm defines ``init_tell`` without
                ``init_ask``.
        """
        from evox import has_init_ask, has_init_tell

        # Coerce to a plain bool: init_step is a static pytree field, so it
        # must be a hashable Python value rather than an arbitrary truthy one.
        init_step = bool(has_init_tell(self.algorithm))

        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let a half-implemented init protocol
        # through and fail later in a less obvious place.
        if init_step and not has_init_ask(self.algorithm):
            raise ValueError(
                f"{type(self.algorithm).__name__} defines init_tell but not "
                "init_ask; both are required for the initialization step."
            )

        algo_state = self.algorithm.init(key)

        return EvoXAlgoState(algo_state=algo_state, init_step=init_step)

    def ask(self, state: EvoXAlgoState) -> tuple[chex.ArrayTree, EvoXAlgoState]:
        """Sample a new population from the wrapped algorithm.

        Uses ``algorithm.init_ask`` during the initialization round (when
        available), otherwise ``algorithm.ask``.

        Args:
            state: Current adapter state.

        Returns:
            A tuple ``(pop, new_state)`` where ``pop`` is the sampled
            population converted to parameter pytrees.
        """
        from evox import has_init_ask

        if has_init_ask(self.algorithm) and state.init_step:
            ask = self.algorithm.init_ask
        else:
            ask = self.algorithm.ask

        flat_pop, algo_state = ask(state.algo_state)

        pop = self.param_vec_spec.to_tree(flat_pop)

        # init_step is intentionally left unchanged here so the matching
        # tell() also takes the init_tell path; tell() clears it.
        return pop, state.replace(algo_state=algo_state)

    def tell(
        self, state: EvoXAlgoState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, EvoXAlgoState]:
        """Report fitnesses of the last asked population to the algorithm.

        Uses ``algorithm.init_tell`` during the initialization round (when
        available), otherwise ``algorithm.tell``.

        Args:
            state: Adapter state returned by the preceding ``ask``.
            fitnesses: Fitness values for the asked population, where larger
                is better (they are negated before reaching EvoX).

        Returns:
            A tuple ``(metrics, new_state)``. ``metrics`` is always an empty
            ``PyTreeDict``; ``new_state`` has ``init_step`` cleared.
        """
        from evox import has_init_tell

        if has_init_tell(self.algorithm) and state.init_step:
            tell = self.algorithm.init_tell
        else:
            tell = self.algorithm.tell

        # Note: EvoX's algorithms minimize the fitness, so negate to turn
        # EvoRL's "larger is better" fitness into an EvoX objective.
        algo_state = tell(state.algo_state, -fitnesses)

        return PyTreeDict(), state.replace(algo_state=algo_state, init_step=False)