Source code for evorl.algorithms.meta.pbt_operations

 1import chex
 2import jax
 3import jax.numpy as jnp
 4from evorl.types import PyTreeDict
 5
 6
def explore(
    parent: chex.ArrayTree,
    key: chex.PRNGKey,
    perturb_factor: dict[str, float],
    search_space: dict[str, dict[str, float]],
):
    """Define the exploration operation for PBT.

    Explores the local neighborhood of an individual, i.e. the mutation
    operator in the context of EC. This is the original exploration
    operator of PBT: every hyperparameter is multiplied by ``1 + u``, where
    ``u`` is sampled uniformly between ``-perturb_factor[hp_name]`` and
    ``perturb_factor[hp_name]``, and the result is clipped to the bounds
    given in ``search_space``.

    Args:
        parent: Mapping from hyperparameter name to its current value.
        key: PRNG key for the perturbation noise. An independent sub-key
            is derived from it for each hyperparameter.
        perturb_factor: Half-width of the relative perturbation range,
            keyed by hyperparameter name.
        search_space: Per-hyperparameter bounds; each entry provides the
            ``"low"`` and ``"high"`` values used for clipping.

    Returns:
        A ``PyTreeDict`` with the same keys as ``parent`` holding the
        perturbed and clipped hyperparameter values.
    """
    offspring = PyTreeDict()
    for hp_index, hp_name in enumerate(parent.keys()):
        # A JAX PRNG key must be consumed only once. Reusing `key` for every
        # hyperparameter would give all of them the same underlying draw and
        # therefore perfectly correlated perturbations.
        hp_key = jax.random.fold_in(key, hp_index)
        noise = jax.random.uniform(
            hp_key,
            minval=-perturb_factor[hp_name],
            maxval=perturb_factor[hp_name],
        )
        val = parent[hp_name] * (1 + noise)
        bounds = search_space[hp_name]
        offspring[hp_name] = jnp.clip(val, min=bounds["low"], max=bounds["high"])

    return offspring
34 35
def select(
    pop_episode_returns: chex.Array, key: chex.PRNGKey, bottoms_num: int, tops_num: int
):
    """Choose top performers whose copies will overwrite the worst ones.

    The population is ranked by episode return in ascending order. The
    first ``bottoms_num`` ranked indices are the individuals to replace;
    for each of them a parent is drawn, with replacement, from the last
    ``tops_num`` ranked indices.

    Args:
        pop_episode_returns: Episode return of every individual.
        key: PRNG key used to sample the parents.
        bottoms_num: Number of worst individuals to replace.
        tops_num: Number of best individuals eligible as parents.
            NOTE(review): with ``tops_num == 0`` the slice ``[-0:]`` covers
            the whole population -- confirm callers never pass 0.

    Returns:
        A tuple ``(tops_choice_indices, bottoms_indices)``: the sampled
        parent indices (one per replaced individual) and the indices of
        the individuals being replaced.
    """
    # argsort is ascending: lowest returns first, highest returns last.
    ranking = jnp.argsort(pop_episode_returns)
    worst, best = ranking[:bottoms_num], ranking[-tops_num:]

    # Sample with replacement so exactly `bottoms_num` parents are produced
    # even when fewer than that many top individuals are available.
    parents = jax.random.choice(key, best, shape=(bottoms_num,), replace=True)

    return parents, worst