Source code for evorl.algorithms.meta.pbt_ppo.pbt_param_ppo
1import chex
2import jax
3from omegaconf import OmegaConf
4
5from evorl.types import PyTreeDict, State
6
7from ..pbt_workflow import PBTWorkflowTemplate, PBTOptState
8from ..pbt_utils import uniform_init, log_uniform_init
9
10
[docs]
11class PBTParamPPOWorkflow(PBTWorkflowTemplate):
[docs]
12 @classmethod
13 def name(cls):
14 return "PBT-ParamPPO"
15
16 def _setup_pop_and_pbt_optimizer(
17 self, key: chex.PRNGKey
18 ) -> tuple[chex.ArrayTree, PBTOptState]:
19 search_space = self.config.search_space
20 pop_size = self.config.pop_size
21
22 def _init(hp, key):
23 match hp:
24 case "actor_loss_weight" | "critic_loss_weight" | "clip_epsilon":
25 return log_uniform_init(search_space[hp], key, pop_size)
26 case "entropy_loss_weight":
27 return -log_uniform_init(
28 OmegaConf.create(
29 dict(low=-search_space[hp].high, high=-search_space[hp].low)
30 ),
31 key,
32 pop_size,
33 )
34 case "discount_g" | "gae_lambda_g":
35 return uniform_init(search_space[hp], key, pop_size)
36
37 pop = PyTreeDict(
38 {
39 hp: _init(hp, key)
40 for hp, key in zip(
41 search_space.keys(), jax.random.split(key, len(search_space))
42 )
43 }
44 )
45
46 return pop, PBTOptState()
47
[docs]
48 def apply_hyperparams_to_workflow_state(
49 self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric]
50 ):
51 agent_state = workflow_state.agent_state
52 agent_state = agent_state.replace(
53 extra_state=agent_state.extra_state.replace(
54 clip_epsilon=hyperparams.clip_epsilon
55 )
56 )
57
58 # make a shadow copy
59 hyperparams = hyperparams.replace()
60 hyperparams.pop("clip_epsilon")
61 hp_state = workflow_state.hp_state.replace(**hyperparams)
62
63 return workflow_state.replace(
64 agent_state=agent_state,
65 hp_state=hp_state,
66 )