1from omegaconf import OmegaConf
2
3import chex
4import jax
5import jax.numpy as jnp
6
7from evorl.types import PyTreeDict, State
8from evorl.utils.jax_utils import tree_zeros_like, tree_set, tree_get
9
10from ..pbt_workflow import PBTOptState
11from ..pbt_utils import uniform_init, log_uniform_init
12from .pbt_param_ppo import PBTParamPPOWorkflow
13
14
[docs]
class PBTCSOOptState(PBTOptState):
    """PBT optimizer state for the competitive-swarm (CSO) variant.

    Extends ``PBTOptState`` with a velocity pytree. In this module the velocity
    is initialized with ``tree_zeros_like(pop)``, so it mirrors the structure of
    the hyperparameter population.
    """

    # per-hyperparameter, per-member CSO velocity
    velocity: chex.ArrayTree
17
18
[docs]
19class PBTCSOParamPPOWorkflow(PBTParamPPOWorkflow):
[docs]
20 @classmethod
21 def name(cls):
22 return "PBT-CSO-ParamPPO"
23
24 def _setup_pop_and_pbt_optimizer(
25 self, key: chex.PRNGKey
26 ) -> tuple[chex.ArrayTree, PBTOptState]:
27 search_space = self.config.search_space
28 pop_size = self.config.pop_size
29
30 assert pop_size % 2 == 0, "pop_size must be even"
31
32 def _init(hp, key):
33 match hp:
34 case "actor_loss_weight" | "critic_loss_weight" | "clip_epsilon":
35 return log_uniform_init(search_space[hp], key, pop_size)
36 case "entropy_loss_weight":
37 return -log_uniform_init(
38 OmegaConf.create(
39 dict(low=-search_space[hp].high, high=-search_space[hp].low)
40 ),
41 key,
42 pop_size,
43 )
44 case "discount_g" | "gae_lambda_g":
45 return uniform_init(search_space[hp], key, pop_size)
46
47 pop = PyTreeDict(
48 {
49 hp: _init(hp, key)
50 for hp, key in zip(
51 search_space.keys(), jax.random.split(key, len(search_space))
52 )
53 }
54 )
55
56 pbt_opt_state = PBTCSOOptState(velocity=tree_zeros_like(pop))
57
58 return pop, pbt_opt_state
59
[docs]
60 def exploit_and_explore(
61 self,
62 pbt_opt_state: PBTOptState, # shared
63 pop: chex.ArrayTree, # sharding
64 pop_workflow_state: State, # sharding
65 pop_metrics: chex.ArrayTree, # sharding
66 key: chex.PRNGKey, # shared
67 ) -> tuple[chex.ArrayTree, State, PBTOptState]:
68 pairing_key, rand_key = jax.random.split(key)
69 velocity = pbt_opt_state.velocity # PyTreeDict
70 pop_size = self.config.pop_size
71 search_space = self.config.search_space
72
73 randperm = jax.random.permutation(pairing_key, pop_size).reshape(2, -1)
74
75 mask = pop_metrics[randperm[0]] > pop_metrics[randperm[1]]
76 teacher_indices = jnp.where(mask, randperm[0], randperm[1]) # fast learner
77 student_indices = jnp.where(mask, randperm[1], randperm[0]) # slow learner
78
79 students_velocity = PyTreeDict()
80 offsprings = PyTreeDict()
81 for hp, key in zip(
82 search_space.keys(), jax.random.split(rand_key, len(search_space))
83 ):
84 v = velocity[hp]
85 x = pop[hp]
86
87 match hp:
88 case "actor_loss_weight" | "critic_loss_weight" | "clip_epsilon":
89 # compute the velocity in log space
90 x = jnp.log(x)
91 case "entropy_loss_weight":
92 x = jnp.log(-x)
93 case "discount_g" | "gae_lambda_g":
94 pass
95
96 r1_key, r2_key = jax.random.split(key)
97 chex.assert_equal_shape((v, x))
98 r1 = jax.random.uniform(r1_key, shape=(pop_size // 2, *v.shape[1:]))
99 r2 = jax.random.uniform(r2_key, shape=(pop_size // 2, *v.shape[1:]))
100 v_stu = r1 * v[student_indices] + r2 * (
101 x[teacher_indices] - x[student_indices]
102 )
103
104 x_stu = x[student_indices] + v_stu
105
106 # turn back to original space
107 match hp:
108 case "actor_loss_weight" | "critic_loss_weight" | "clip_epsilon":
109 x_stu = jnp.exp(x_stu)
110 case "entropy_loss_weight":
111 x_stu = -jnp.exp(x_stu)
112 case "discount_g" | "gae_lambda_g":
113 pass
114
115 x_stu = jnp.clip(
116 x_stu, min=search_space[hp]["low"], max=search_space[hp]["high"]
117 )
118
119 students_velocity[hp] = v_stu
120 offsprings[hp] = x_stu
121
122 # Note: no need to deepcopy teachers_wf_state here, since it should be
123 # ensured immutable in apply_hyperparams_to_workflow_state()
124 teachers_wf_state = tree_get(pop_workflow_state, teacher_indices)
125 offsprings_workflow_state = jax.vmap(self.apply_hyperparams_to_workflow_state)(
126 teachers_wf_state, offsprings
127 )
128
129 velocity = tree_set(
130 velocity, students_velocity, student_indices, unique_indices=True
131 )
132 pbt_opt_state = pbt_opt_state.replace(velocity=velocity)
133
134 # ==== survival | merge population ====
135 pop = tree_set(pop, offsprings, student_indices, unique_indices=True)
136 pop_workflow_state = tree_set(
137 pop_workflow_state,
138 offsprings_workflow_state,
139 student_indices,
140 unique_indices=True,
141 )
142
143 return pop, pop_workflow_state, pbt_opt_state