Source code for evorl.algorithms.meta.pbt_sac.pbt_param_sac

 1import chex
 2import jax
 3
 4from evorl.types import PyTreeDict, State
 5
 6from ..pbt_workflow import PBTOffpolicyWorkflowTemplate, PBTOptState
 7from ..pbt_utils import uniform_init, log_uniform_init
 8
 9
[docs] 10class PBTParamSACWorkflow(PBTOffpolicyWorkflowTemplate):
[docs] 11 @classmethod 12 def name(cls): 13 return "PBT-ParamSAC"
14 15 def _setup_pop_and_pbt_optimizer( 16 self, key: chex.PRNGKey 17 ) -> tuple[chex.ArrayTree, PBTOptState]: 18 search_space = self.config.search_space 19 pop_size = self.config.pop_size 20 21 def _init(hp, key): 22 match hp: 23 case "discount_g" | "log_alpha": 24 return uniform_init(search_space[hp], key, pop_size) 25 case "actor_loss_weight" | "critic_loss_weight": 26 return log_uniform_init(search_space[hp], key, pop_size) 27 28 pop = PyTreeDict( 29 { 30 hp: _init(hp, key) 31 for hp, key in zip( 32 search_space.keys(), jax.random.split(key, len(search_space)) 33 ) 34 } 35 ) 36 37 return pop, PBTOptState() 38
[docs] 39 def apply_hyperparams_to_workflow_state( 40 self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric] 41 ): 42 agent_state = workflow_state.agent_state 43 agent_state = agent_state.replace( 44 params=agent_state.params.replace( 45 log_alpha=hyperparams.log_alpha, 46 ), 47 extra_state=agent_state.extra_state.replace( 48 discount_g=hyperparams.discount_g, 49 ), 50 ) 51 52 # make a shadow copy 53 hyperparams = hyperparams.replace() 54 hyperparams.pop("discount_g") 55 hyperparams.pop("log_alpha") 56 hp_state = workflow_state.hp_state.replace(**hyperparams) 57 58 return workflow_state.replace( 59 agent_state=agent_state, 60 hp_state=hp_state, 61 )