Source code for evorl.algorithms.meta.pbt
1import logging
2
3import chex
4import optax
5from optax.schedules import InjectStatefulHyperparamsState
6
7from evorl.types import PyTreeDict, State
8from evorl.utils.jax_utils import tree_deepcopy
9
10from .pbt_workflow import PBTWorkflowTemplate, PBTOptState
11from .pbt_utils import log_uniform_init
12
13logger = logging.getLogger(__name__)
14
15
[docs]
class PBTWorkflow(PBTWorkflowTemplate):
    """Minimal PBT example whose only tuned hyperparameter is the PPO learning rate."""

    @classmethod
    def name(cls):
        """Return the identifier string of this workflow."""
        return "PBT"

    def _customize_optimizer(self) -> None:
        """Replace the target workflow's optimizer with a hyperparameter-injected Adam.

        ``learning_rate`` is the only Adam argument left injectable; ``b1``,
        ``b2``, ``eps`` and ``eps_root`` are declared static. The initial
        learning rate is the lower bound of the configured lr search space.
        """
        adam_factory = optax.inject_hyperparams(
            optax.adam, static_args=("b1", "b2", "eps", "eps_root")
        )
        initial_lr = self.config.search_space.lr.low
        self.workflow.optimizer = adam_factory(learning_rate=initial_lr)

    def _setup_pop_and_pbt_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[chex.ArrayTree, PBTOptState]:
        """Build the initial hyperparameter population and the PBT optimizer state.

        Args:
            key: PRNG key forwarded to ``log_uniform_init``.

        Returns:
            A ``(pop, pbt_opt_state)`` pair, where ``pop`` holds a single ``lr``
            entry produced by ``log_uniform_init`` from the configured lr search
            space and ``pop_size``, and ``pbt_opt_state`` is a default-constructed
            ``PBTOptState``.
        """
        lr_space = self.config.search_space.lr
        sampled_lrs = log_uniform_init(lr_space, key, self.config.pop_size)
        return PyTreeDict(lr=sampled_lrs), PBTOptState()

    def apply_hyperparams_to_workflow_state(
        self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric]
    ) -> State:
        """Return ``workflow_state`` with ``hyperparams.lr`` as its learning rate.

        Args:
            workflow_state: State of the target workflow; its ``opt_state`` must
                be an ``InjectStatefulHyperparamsState``.
            hyperparams: Hyperparameters to apply; only ``lr`` is read.

        Returns:
            The result of ``workflow_state.replace`` with the updated optimizer
            state.
        """
        current_opt_state = workflow_state.opt_state
        assert isinstance(current_opt_state, InjectStatefulHyperparamsState)
        # The optimizer state is a NamedTuple, but its `hyperparams` container is
        # written by item assignment below. Work on a copy so that write does not
        # land on the object still referenced by `workflow_state`
        # (assumes tree_deepcopy returns an independent copy -- TODO confirm).
        updated_opt_state = tree_deepcopy(current_opt_state)
        updated_opt_state.hyperparams["learning_rate"] = hyperparams.lr
        return workflow_state.replace(opt_state=updated_opt_state)