Source code for evorl.algorithms.meta.pbt_ppo.pbt_cso_param_ppo

  1from omegaconf import OmegaConf
  2
  3import chex
  4import jax
  5import jax.numpy as jnp
  6
  7from evorl.types import PyTreeDict, State
  8from evorl.utils.jax_utils import tree_zeros_like, tree_set, tree_get
  9
 10from ..pbt_workflow import PBTOptState
 11from ..pbt_utils import uniform_init, log_uniform_init
 12from .pbt_param_ppo import PBTParamPPOWorkflow
 13
 14
class PBTCSOOptState(PBTOptState):
    """Optimizer state for the CSO-based PBT workflow.

    Extends the base PBT optimizer state with a velocity tree used by the
    swarm-style hyperparameter update.
    """

    # Velocity of every individual. PBTCSOParamPPOWorkflow initialises it as
    # zeros with the same tree structure as the hyperparameter population.
    velocity: chex.ArrayTree
17 18
[docs] 19class PBTCSOParamPPOWorkflow(PBTParamPPOWorkflow):
[docs] 20 @classmethod 21 def name(cls): 22 return "PBT-CSO-ParamPPO"
23 24 def _setup_pop_and_pbt_optimizer( 25 self, key: chex.PRNGKey 26 ) -> tuple[chex.ArrayTree, PBTOptState]: 27 search_space = self.config.search_space 28 pop_size = self.config.pop_size 29 30 assert pop_size % 2 == 0, "pop_size must be even" 31 32 def _init(hp, key): 33 match hp: 34 case "actor_loss_weight" | "critic_loss_weight" | "clip_epsilon": 35 return log_uniform_init(search_space[hp], key, pop_size) 36 case "entropy_loss_weight": 37 return -log_uniform_init( 38 OmegaConf.create( 39 dict(low=-search_space[hp].high, high=-search_space[hp].low) 40 ), 41 key, 42 pop_size, 43 ) 44 case "discount_g" | "gae_lambda_g": 45 return uniform_init(search_space[hp], key, pop_size) 46 47 pop = PyTreeDict( 48 { 49 hp: _init(hp, key) 50 for hp, key in zip( 51 search_space.keys(), jax.random.split(key, len(search_space)) 52 ) 53 } 54 ) 55 56 pbt_opt_state = PBTCSOOptState(velocity=tree_zeros_like(pop)) 57 58 return pop, pbt_opt_state 59
[docs] 60 def exploit_and_explore( 61 self, 62 pbt_opt_state: PBTOptState, # shared 63 pop: chex.ArrayTree, # sharding 64 pop_workflow_state: State, # sharding 65 pop_metrics: chex.ArrayTree, # sharding 66 key: chex.PRNGKey, # shared 67 ) -> tuple[chex.ArrayTree, State, PBTOptState]: 68 pairing_key, rand_key = jax.random.split(key) 69 velocity = pbt_opt_state.velocity # PyTreeDict 70 pop_size = self.config.pop_size 71 search_space = self.config.search_space 72 73 randperm = jax.random.permutation(pairing_key, pop_size).reshape(2, -1) 74 75 mask = pop_metrics[randperm[0]] > pop_metrics[randperm[1]] 76 teacher_indices = jnp.where(mask, randperm[0], randperm[1]) # fast learner 77 student_indices = jnp.where(mask, randperm[1], randperm[0]) # slow learner 78 79 students_velocity = PyTreeDict() 80 offsprings = PyTreeDict() 81 for hp, key in zip( 82 search_space.keys(), jax.random.split(rand_key, len(search_space)) 83 ): 84 v = velocity[hp] 85 x = pop[hp] 86 87 match hp: 88 case "actor_loss_weight" | "critic_loss_weight" | "clip_epsilon": 89 # compute the velocity in log space 90 x = jnp.log(x) 91 case "entropy_loss_weight": 92 x = jnp.log(-x) 93 case "discount_g" | "gae_lambda_g": 94 pass 95 96 r1_key, r2_key = jax.random.split(key) 97 chex.assert_equal_shape((v, x)) 98 r1 = jax.random.uniform(r1_key, shape=(pop_size // 2, *v.shape[1:])) 99 r2 = jax.random.uniform(r2_key, shape=(pop_size // 2, *v.shape[1:])) 100 v_stu = r1 * v[student_indices] + r2 * ( 101 x[teacher_indices] - x[student_indices] 102 ) 103 104 x_stu = x[student_indices] + v_stu 105 106 # turn back to original space 107 match hp: 108 case "actor_loss_weight" | "critic_loss_weight" | "clip_epsilon": 109 x_stu = jnp.exp(x_stu) 110 case "entropy_loss_weight": 111 x_stu = -jnp.exp(x_stu) 112 case "discount_g" | "gae_lambda_g": 113 pass 114 115 x_stu = jnp.clip( 116 x_stu, min=search_space[hp]["low"], max=search_space[hp]["high"] 117 ) 118 119 students_velocity[hp] = v_stu 120 offsprings[hp] = x_stu 121 122 # 
Note: no need to deepcopy teachers_wf_state here, since it should be 123 # ensured immutable in apply_hyperparams_to_workflow_state() 124 teachers_wf_state = tree_get(pop_workflow_state, teacher_indices) 125 offsprings_workflow_state = jax.vmap(self.apply_hyperparams_to_workflow_state)( 126 teachers_wf_state, offsprings 127 ) 128 129 velocity = tree_set( 130 velocity, students_velocity, student_indices, unique_indices=True 131 ) 132 pbt_opt_state = pbt_opt_state.replace(velocity=velocity) 133 134 # ==== survival | merge population ==== 135 pop = tree_set(pop, offsprings, student_indices, unique_indices=True) 136 pop_workflow_state = tree_set( 137 pop_workflow_state, 138 offsprings_workflow_state, 139 student_indices, 140 unique_indices=True, 141 ) 142 143 return pop, pop_workflow_state, pbt_opt_state