Source code for evorl.algorithms.meta.pbt_openes.pbt_param_openes

 1import chex
 2import jax
 3
 4from evorl.types import PyTreeDict, State
 5from evorl.metrics import EvaluateMetric
 6from evorl.utils.jax_utils import tree_deepcopy
 7from evorl.distributed import shmap_vmap
 8
 9from ..pbt_workflow import PBTWorkflowTemplate, PBTOptState, PBTEvalMetric
10from ..pbt_utils import log_uniform_init
11
12
class PBTParamOpenESWorkflow(PBTWorkflowTemplate):
    """PBT workflow that tunes the ES hyperparameters of an inner workflow.

    Each population member is an inner (OpenES-style) workflow. PBT searches
    over the entries of ``config.search_space``; the entries written back into
    each member's state are ``ec_lr`` (the optax learning rate) and
    ``ec_noise_std`` (the ES noise std).
    """

    @classmethod
    def name(cls):
        return "PBT-ParamOpenES"

    def _setup_pop_and_pbt_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[chex.ArrayTree, PBTOptState]:
        """Sample the initial hyperparameter population.

        Every entry of ``config.search_space`` gets its own PRNG key and is
        sampled with ``log_uniform_init`` for ``config.pop_size`` members.

        Args:
            key: PRNG key used for sampling.

        Returns:
            ``(pop, pbt_opt_state)``: ``pop`` maps each hyperparameter name to
            its sampled values; ``pbt_opt_state`` is an empty ``PBTOptState``.
        """
        search_space = self.config.search_space
        pop_size = self.config.pop_size

        # One independent key per hyperparameter, in search_space order.
        hp_keys = jax.random.split(key, len(search_space))
        pop = PyTreeDict(
            {
                # expected names: "ec_noise_std" | "ec_lr"
                hp: log_uniform_init(search_space[hp], hp_key, pop_size)
                for hp, hp_key in zip(search_space.keys(), hp_keys)
            }
        )

        return pop, PBTOptState()

    def apply_hyperparams_to_workflow_state(
        self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric]
    ) -> State:
        """Write one member's hyperparameters into its inner workflow state.

        Args:
            workflow_state: inner workflow state of a single population member.
            hyperparams: that member's hyperparameters; must provide
                ``ec_lr`` and ``ec_noise_std``.

        Returns:
            A new workflow state whose optax ``learning_rate`` hyperparameter
            is ``ec_lr`` and whose ``ec_opt_state.noise_std`` is
            ``ec_noise_std``.
        """
        ec_opt_state = workflow_state.ec_opt_state

        # Copy first: the ``hyperparams`` dict is mutated in place below, and
        # mutating the original would alias the caller's workflow_state.
        # NOTE(review): assumes opt_state comes from optax.inject_hyperparams
        # (it exposes a ``hyperparams`` dict) — confirm against inner workflow.
        optax_opt_state = tree_deepcopy(ec_opt_state.opt_state)
        optax_opt_state.hyperparams["learning_rate"] = hyperparams.ec_lr

        ec_opt_state = ec_opt_state.replace(
            opt_state=optax_opt_state,
            noise_std=hyperparams.ec_noise_std,
        )

        return workflow_state.replace(ec_opt_state=ec_opt_state)

    def evaluate(self, state: State) -> tuple[PBTEvalMetric, State]:
        """Evaluate the pop-center of every member's inner workflow.

        Args:
            state: PBT workflow state holding ``pop_workflow_state`` and ``key``.

        Returns:
            ``(eval_metrics, new_state)``: a ``PBTEvalMetric`` with one mean
            episode return / length per member, and ``state`` with its PRNG
            key advanced.
        """
        next_key, eval_key = jax.random.split(state.key, num=2)

        def _evaluate_member(wf_state, member_key):
            # Evaluate the pop-center agent returned by the inner workflow.
            agent_state = self.workflow._get_pop_center(wf_state)

            # Raw metrics carry a leading [#episodes] axis; reduce each to its
            # mean so every member yields scalar metrics.
            raw_eval_metrics = self.evaluator.evaluate(
                agent_state, member_key, num_episodes=self.config.eval_episodes
            )

            return EvaluateMetric(
                episode_returns=raw_eval_metrics.episode_returns.mean(),
                episode_lengths=raw_eval_metrics.episode_lengths.mean(),
            )

        eval_fn = shmap_vmap(
            _evaluate_member,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        # One evaluation key per population member.
        member_keys = jax.random.split(eval_key, self.config.pop_size)
        pop_eval_metrics = eval_fn(state.pop_workflow_state, member_keys)

        eval_metrics = PBTEvalMetric(
            pop_episode_returns=pop_eval_metrics.episode_returns,
            pop_episode_lengths=pop_eval_metrics.episode_lengths,
        )

        return eval_metrics, state.replace(key=next_key)