Source code for evorl.algorithms.meta.pbt_openes.pbt_param_openes
1import chex
2import jax
3
4from evorl.types import PyTreeDict, State
5from evorl.metrics import EvaluateMetric
6from evorl.utils.jax_utils import tree_deepcopy
7from evorl.distributed import shmap_vmap
8
9from ..pbt_workflow import PBTWorkflowTemplate, PBTOptState, PBTEvalMetric
10from ..pbt_utils import log_uniform_init
11
12
[docs]
class PBTParamOpenESWorkflow(PBTWorkflowTemplate):
    """PBT workflow that tunes OpenES hyperparameters across a population.

    Each population member carries its own ES learning rate (``ec_lr``) and
    noise standard deviation (``ec_noise_std``). These are injected into that
    member's ES optimizer state before training. Members are scored by
    evaluating their ES population center.
    """

    @classmethod
    def name(cls):
        """Return the human-readable identifier of this workflow."""
        return "PBT-ParamOpenES"

    def _setup_pop_and_pbt_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[chex.ArrayTree, PBTOptState]:
        """Sample the initial hyperparameter population.

        Args:
            key: PRNG key. It is split into one independent sub-key per
                hyperparameter in ``config.search_space``.

        Returns:
            A ``(pop, pbt_opt_state)`` pair:
            - ``pop`` is a ``PyTreeDict`` mapping each search-space name
              (expected: ``"ec_noise_std"`` and ``"ec_lr"``) to ``pop_size``
              values produced by ``log_uniform_init``. Presumably these are
              log-uniform samples from the configured range; confirm against
              ``pbt_utils``.
            - ``pbt_opt_state`` is an empty ``PBTOptState``.
        """
        search_space = self.config.search_space
        pop_size = self.config.pop_size

        hp_names = list(search_space.keys())
        # One sub-key per hyperparameter, so each is sampled independently.
        hp_keys = jax.random.split(key, len(hp_names))

        pop = PyTreeDict(
            {
                hp: log_uniform_init(search_space[hp], hp_key, pop_size)
                for hp, hp_key in zip(hp_names, hp_keys)
            }
        )

        return pop, PBTOptState()

    def apply_hyperparams_to_workflow_state(
        self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric]
    ) -> State:
        """Write one member's hyperparameters into its ES optimizer state.

        Args:
            workflow_state: The member's inner workflow state. Its
                ``ec_opt_state`` must expose ``opt_state.hyperparams`` and
                ``noise_std``.
            hyperparams: Must provide ``ec_lr`` and ``ec_noise_std``.

        Returns:
            A new workflow state where the optax ``learning_rate``
            hyperparameter is set to ``ec_lr`` and ``noise_std`` is set to
            ``ec_noise_std``.
        """
        ec_opt_state = workflow_state.ec_opt_state

        # The ``hyperparams`` mapping is mutated in place below. Copy the
        # optax state first so the caller's state is not aliased/modified.
        optax_opt_state = tree_deepcopy(ec_opt_state.opt_state)
        optax_opt_state.hyperparams["learning_rate"] = hyperparams.ec_lr

        ec_opt_state = ec_opt_state.replace(
            opt_state=optax_opt_state,
            noise_std=hyperparams.ec_noise_std,
        )

        return workflow_state.replace(ec_opt_state=ec_opt_state)

    def evaluate(self, state: State) -> tuple[PBTEvalMetric, State]:
        """Evaluate the ES population center of every member in parallel.

        Args:
            state: PBT state holding ``key`` and the stacked
                ``pop_workflow_state``.

        Returns:
            A ``(eval_metrics, new_state)`` tuple:
            - ``eval_metrics`` is a ``PBTEvalMetric`` with per-member mean
              episode returns and lengths.
            - ``new_state`` is ``state`` with its PRNG key advanced.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        def _evaluate(wf_state, member_key):
            # Evaluate the ES population center, not individual ES samples.
            agent_state = self.workflow._get_pop_center(wf_state)

            # Raw metrics hold per-episode values: [#episodes].
            raw_eval_metrics = self.evaluator.evaluate(
                agent_state, member_key, num_episodes=self.config.eval_episodes
            )

            # Reduce to a single scalar per member by averaging over episodes.
            return EvaluateMetric(
                episode_returns=raw_eval_metrics.episode_returns.mean(),
                episode_lengths=raw_eval_metrics.episode_lengths.mean(),
            )

        # Shard the population across the mesh and vmap within each shard.
        eval_fn = shmap_vmap(
            _evaluate,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        member_keys = jax.random.split(eval_key, self.config.pop_size)
        pop_eval_metrics = eval_fn(state.pop_workflow_state, member_keys)

        eval_metrics = PBTEvalMetric(
            pop_episode_returns=pop_eval_metrics.episode_returns,
            pop_episode_lengths=pop_eval_metrics.episode_lengths,
        )

        return eval_metrics, state.replace(key=key)
86 pop_episode_returns=pop_eval_metrics.episode_returns,
87 pop_episode_lengths=pop_eval_metrics.episode_lengths,
88 )
89
90 return eval_metrics, state.replace(key=key)