Source code for evorl.algorithms.meta.pbt_ppo.pbt_param_ppo

 1import chex
 2import jax
 3from omegaconf import OmegaConf
 4
 5from evorl.types import PyTreeDict, State
 6
 7from ..pbt_workflow import PBTWorkflowTemplate, PBTOptState
 8from ..pbt_utils import uniform_init, log_uniform_init
 9
10
[docs] 11class PBTParamPPOWorkflow(PBTWorkflowTemplate):
[docs] 12 @classmethod 13 def name(cls): 14 return "PBT-ParamPPO"
15 16 def _setup_pop_and_pbt_optimizer( 17 self, key: chex.PRNGKey 18 ) -> tuple[chex.ArrayTree, PBTOptState]: 19 search_space = self.config.search_space 20 pop_size = self.config.pop_size 21 22 def _init(hp, key): 23 match hp: 24 case "actor_loss_weight" | "critic_loss_weight" | "clip_epsilon": 25 return log_uniform_init(search_space[hp], key, pop_size) 26 case "entropy_loss_weight": 27 return -log_uniform_init( 28 OmegaConf.create( 29 dict(low=-search_space[hp].high, high=-search_space[hp].low) 30 ), 31 key, 32 pop_size, 33 ) 34 case "discount_g" | "gae_lambda_g": 35 return uniform_init(search_space[hp], key, pop_size) 36 37 pop = PyTreeDict( 38 { 39 hp: _init(hp, key) 40 for hp, key in zip( 41 search_space.keys(), jax.random.split(key, len(search_space)) 42 ) 43 } 44 ) 45 46 return pop, PBTOptState() 47
[docs] 48 def apply_hyperparams_to_workflow_state( 49 self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric] 50 ): 51 agent_state = workflow_state.agent_state 52 agent_state = agent_state.replace( 53 extra_state=agent_state.extra_state.replace( 54 clip_epsilon=hyperparams.clip_epsilon 55 ) 56 ) 57 58 # make a shadow copy 59 hyperparams = hyperparams.replace() 60 hyperparams.pop("clip_epsilon") 61 hp_state = workflow_state.hp_state.replace(**hyperparams) 62 63 return workflow_state.replace( 64 agent_state=agent_state, 65 hp_state=hp_state, 66 )