Source code for evorl.algorithms.meta.pbt_operations
1import chex
2import jax
3import jax.numpy as jnp
4from evorl.types import PyTreeDict
5
6
[docs]
def explore(
    parent: chex.ArrayTree,
    key: chex.PRNGKey,
    perturb_factor: dict[str, float],
    search_space: dict[str, dict[str, float]],
):
    """Define the exploration operation for PBT.

    Normally explores the local neighborhood of an individual,
    i.e., the mutation op in the context of EC.
    Here we use the original exploration operator in PBT: each
    hyperparameter is multiplied by ``1 + u``, where ``u`` is sampled
    uniformly between ``-perturb_factor[name]`` and ``perturb_factor[name]``,
    and the result is clipped to the bounds given in ``search_space``.

    Args:
        parent: Mapping from hyperparameter name to its current value.
        key: PRNG key. It is split so that every hyperparameter receives
            its own independent random draw.
        perturb_factor: Per-hyperparameter half-width of the relative
            perturbation.
        search_space: Per-hyperparameter bounds, read from the ``"low"``
            and ``"high"`` entries.

    Returns:
        A ``PyTreeDict`` holding the perturbed and clipped value for every
        key of ``parent``.
    """
    hp_names = list(parent.keys())
    offspring = PyTreeDict()
    if not hp_names:
        return offspring

    # Passing the same key to every jax.random.uniform call would yield the
    # same underlying sample for all hyperparameters (only rescaled), making
    # their perturbations perfectly correlated. Derive one subkey per
    # hyperparameter instead.
    hp_keys = jax.random.split(key, len(hp_names))

    for i, hp_name in enumerate(hp_names):
        factor = perturb_factor[hp_name]
        noise = jax.random.uniform(hp_keys[i], minval=-factor, maxval=factor)
        val = parent[hp_name] * (1 + noise)

        bounds = search_space[hp_name]
        offspring[hp_name] = jnp.clip(val, min=bounds["low"], max=bounds["high"])

    return offspring
34
35
[docs]
def select(
    pop_episode_returns: chex.Array, key: chex.PRNGKey, bottoms_num: int, tops_num: int
):
    """Pick the worst individuals and sample a top parent for each of them.

    The population is ranked by episode return in ascending order. The
    first ``bottoms_num`` ranked indices are the individuals to replace and
    the last ``tops_num`` ranked indices form the pool of candidate parents.
    One parent is drawn from that pool, with replacement, per individual
    to replace.

    Note:
        The pool is taken with the slice ``[-tops_num:]``; for
        ``tops_num == 0`` that slice covers the whole ranking.

    Args:
        pop_episode_returns: Episode returns, one entry per individual.
        key: PRNG key used for sampling the parents.
        bottoms_num: Number of lowest-return individuals to replace.
        tops_num: Number of highest-return individuals eligible as parents.

    Returns:
        A tuple ``(parent_indices, replaced_indices)``; both have
        ``bottoms_num`` entries.
    """
    # jnp.argsort ranks ascending: worst individuals first, best last.
    ranking = jnp.argsort(pop_episode_returns)
    replaced_indices = ranking[:bottoms_num]
    parent_pool = ranking[-tops_num:]

    # Sampling with replacement always yields exactly bottoms_num parents,
    # even when the pool is smaller than the number of individuals replaced.
    parent_indices = jax.random.choice(
        key, parent_pool, shape=(bottoms_num,), replace=True
    )

    return parent_indices, replaced_indices