Source code for evorl.algorithms.meta.pbt

 1import logging
 2
 3import chex
 4import optax
 5from optax.schedules import InjectStatefulHyperparamsState
 6
 7from evorl.types import PyTreeDict, State
 8from evorl.utils.jax_utils import tree_deepcopy
 9
10from .pbt_workflow import PBTWorkflowTemplate, PBTOptState
11from .pbt_utils import log_uniform_init
12
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
14
15
class PBTWorkflow(PBTWorkflowTemplate):
    """A minimal example of PBT that tunes the learning rate (lr) of PPO.

    The target workflow's optimizer is replaced by ``optax.adam`` wrapped in
    ``optax.inject_hyperparams``. The learning rate therefore lives in the
    optimizer state, and each population member's value can be written into
    that state without rebuilding the optimizer.
    """

    @classmethod
    def name(cls):
        """Return the short name of this workflow (``"PBT"``)."""
        return "PBT"

    def _customize_optimizer(self) -> None:
        """Customize the target workflow's optimizer.

        Replaces ``self.workflow.optimizer`` with Adam whose ``learning_rate``
        is an injected hyperparameter stored in the optimizer state.
        ``b1``, ``b2``, ``eps`` and ``eps_root`` are declared static. Because
        they are not passed here, they keep Adam's defaults.
        """
        # The lower bound of the search space is only the initial value.
        # Per-member values are written into the optimizer state by
        # ``apply_hyperparams_to_workflow_state``. This assumes the template
        # applies it to every member before training -- confirm in
        # PBTWorkflowTemplate.
        self.workflow.optimizer = optax.inject_hyperparams(
            optax.adam, static_args=("b1", "b2", "eps", "eps_root")
        )(learning_rate=self.config.search_space.lr.low)

    def _setup_pop_and_pbt_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[chex.ArrayTree, PBTOptState]:
        """Sample the initial population of hyperparameters.

        Args:
            key: PRNG key used to sample the initial learning rates.

        Returns:
            A ``(pop, pbt_opt_state)`` pair. ``pop`` is a ``PyTreeDict`` with a
            single ``lr`` entry holding the ``config.pop_size`` learning rates
            drawn by ``log_uniform_init`` from ``config.search_space.lr``.
            ``pbt_opt_state`` is a default-constructed ``PBTOptState()``.
        """
        lr_samples = log_uniform_init(
            self.config.search_space.lr, key, self.config.pop_size
        )
        return PyTreeDict(lr=lr_samples), PBTOptState()

    def apply_hyperparams_to_workflow_state(
        self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric]
    ) -> State:
        """Write one member's learning rate into its optimizer state.

        Args:
            workflow_state: State of the target workflow. Its ``opt_state``
                must be an ``InjectStatefulHyperparamsState``, as produced by
                the optimizer installed in ``_customize_optimizer``.
            hyperparams: Hyperparameters of one population member. Only
                ``hyperparams.lr`` is read.

        Returns:
            ``workflow_state.replace(...)`` with an optimizer state whose
            ``learning_rate`` hyperparameter is ``hyperparams.lr``. The
            caller's ``opt_state`` and its ``hyperparams`` dict are not
            mutated.

        Raises:
            TypeError: If ``workflow_state.opt_state`` is not an
                ``InjectStatefulHyperparamsState``.
        """
        opt_state = workflow_state.opt_state
        # Explicit check rather than ``assert``, so it survives ``python -O``.
        if not isinstance(opt_state, InjectStatefulHyperparamsState):
            raise TypeError(
                "Expected workflow_state.opt_state to be an "
                f"InjectStatefulHyperparamsState, got {type(opt_state).__name__}."
            )
        # The NamedTuple itself is immutable, but its ``hyperparams`` field is
        # a plain mutable dict shared with the caller's state. Mutating it in
        # place would leak into ``workflow_state``. Build a fresh dict and a
        # new NamedTuple instead. The other fields (count, inner Adam state)
        # hold immutable JAX arrays and can safely be shared, so the whole
        # optimizer state does not need a deep copy.
        new_opt_state = opt_state._replace(
            hyperparams={**opt_state.hyperparams, "learning_rate": hyperparams.lr}
        )
        return workflow_state.replace(opt_state=new_opt_state)