Source code for evorl.algorithms.meta.pbt_sac.pbt_param_sac
1import chex
2import jax
3
4from evorl.types import PyTreeDict, State
5
6from ..pbt_workflow import PBTOffpolicyWorkflowTemplate, PBTOptState
7from ..pbt_utils import uniform_init, log_uniform_init
8
9
[docs]
10class PBTParamSACWorkflow(PBTOffpolicyWorkflowTemplate):
[docs]
11 @classmethod
12 def name(cls):
13 return "PBT-ParamSAC"
14
15 def _setup_pop_and_pbt_optimizer(
16 self, key: chex.PRNGKey
17 ) -> tuple[chex.ArrayTree, PBTOptState]:
18 search_space = self.config.search_space
19 pop_size = self.config.pop_size
20
21 def _init(hp, key):
22 match hp:
23 case "discount_g" | "log_alpha":
24 return uniform_init(search_space[hp], key, pop_size)
25 case "actor_loss_weight" | "critic_loss_weight":
26 return log_uniform_init(search_space[hp], key, pop_size)
27
28 pop = PyTreeDict(
29 {
30 hp: _init(hp, key)
31 for hp, key in zip(
32 search_space.keys(), jax.random.split(key, len(search_space))
33 )
34 }
35 )
36
37 return pop, PBTOptState()
38
[docs]
39 def apply_hyperparams_to_workflow_state(
40 self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric]
41 ):
42 agent_state = workflow_state.agent_state
43 agent_state = agent_state.replace(
44 params=agent_state.params.replace(
45 log_alpha=hyperparams.log_alpha,
46 ),
47 extra_state=agent_state.extra_state.replace(
48 discount_g=hyperparams.discount_g,
49 ),
50 )
51
52 # make a shadow copy
53 hyperparams = hyperparams.replace()
54 hyperparams.pop("discount_g")
55 hyperparams.pop("log_alpha")
56 hp_state = workflow_state.hp_state.replace(**hyperparams)
57
58 return workflow_state.replace(
59 agent_state=agent_state,
60 hp_state=hp_state,
61 )