Source code for evorl.algorithms.meta.pbt_utils

 1import pandas as pd
 2
 3import chex
 4import jax
 5import jax.numpy as jnp
 6
 7
def convert_pop_to_df(pop):
    """Tabulate a population and prepend a ``pop_id`` column.

    Args:
        pop: Any input accepted by ``pd.DataFrame.from_dict`` (for example a
            dict mapping column names to equal-length sequences).

    Returns:
        A DataFrame built from ``pop`` whose first column, ``pop_id``, holds
        the row positions ``0 .. n_rows - 1``.
    """
    frame = pd.DataFrame.from_dict(pop)
    n_rows = frame.shape[0]
    frame.insert(loc=0, column="pop_id", value=range(n_rows))
    return frame
12 13
def uniform_init(search_space, key: chex.PRNGKey, num: int) -> chex.Array:
    """Draw ``num`` values uniformly between ``search_space.low`` and ``.high``.

    Args:
        search_space: Object exposing ``low`` and ``high`` bounds, with
            ``low <= high``.
        key: JAX PRNG key used for sampling.
        num: Number of samples to draw.

    Returns:
        Array of shape ``(num,)`` sampled with ``jax.random.uniform``
        (half-open interval ``[low, high)``).
    """
    lo, hi = search_space.low, search_space.high
    assert lo <= hi
    return jax.random.uniform(key, shape=(num,), minval=lo, maxval=hi)
23 24
def truncated_normal_init(
    search_space, key: chex.PRNGKey, num: int, m=0.95
) -> chex.Array:
    """Random sample from a normal distribution truncated to ``[low, high]``.

    The normal is centred at the midpoint of the search space. Its standard
    deviation is chosen so that the central ``m`` probability mass of the
    *untruncated* normal spans exactly ``[low, high]``. For example,
    ``m=0.95`` gives ``std = (high - low) / (2 * 1.96)``. Truncation then
    discards any mass that falls outside the bounds.

    Args:
        search_space: Object exposing ``low`` and ``high`` bounds, with
            ``low <= high``.
        key: JAX PRNG key used for sampling.
        num: Number of samples to draw.
        m: Central probability mass of the untruncated normal that covers
            ``[low, high]``. Must lie strictly between 0 and 1.

    Returns:
        Array of shape ``(num,)`` with every value in ``[low, high]``. When
        ``low == high``, every entry equals ``low``.
    """
    assert search_space.low <= search_space.high
    # m == 1 would give z = inf (std = 0, NaN output); m <= 0 gives z <= 0.
    assert 0 < m < 1

    low, high = search_space.low, search_space.high

    if low == high:
        # Degenerate range: std would be 0 and the standardized bounds 0/0 = NaN.
        return jnp.full((num,), low, dtype=float)

    # z-score of the two-sided m interval (e.g. 1.96 for m = 0.95): a sample
    # from the untruncated N(mu, std) falls within [low, high] with probability m.
    z = jax.scipy.stats.norm.ppf(1 - (1 - m) / 2)

    mu = (high + low) / 2
    std = (high - low) / (2 * z)

    # jax.random.truncated_normal samples a *standard* normal, so the bounds
    # are expressed in standardized units and the result is rescaled after.
    samples = jax.random.truncated_normal(
        key,
        lower=(low - mu) / std,
        upper=(high - mu) / std,
        shape=(num,),
    )
    samples = samples * std + mu

    # Rescaling can overshoot the bounds by float rounding; clamp to the range.
    return jnp.clip(samples, low, high)
49 50
def log_uniform_init(search_space, key: chex.PRNGKey, num: int) -> chex.Array:
    """Draw ``num`` values whose logarithm is uniform over ``[log(low), log(high))``.

    Suitable for strictly positive hyperparameters that need to be explored
    across several orders of magnitude, e.g. ``[1e-3, 100]``.

    Args:
        search_space: Object exposing ``low`` and ``high`` bounds. Both must be
            positive, with ``low <= high``.
        key: JAX PRNG key used for sampling.
        num: Number of samples to draw.

    Returns:
        Array of shape ``(num,)``.
    """
    lo, hi = search_space.low, search_space.high
    assert lo > 0 and hi > 0 and lo <= hi

    log_lo = jnp.log(lo)
    log_hi = jnp.log(hi)
    exponents = jax.random.uniform(key, (num,), minval=log_lo, maxval=log_hi)
    return jnp.exp(exponents)
70 71
def exp_uniform_init(search_space, key: chex.PRNGKey, num: int) -> chex.Array:
    """Return ``exp(-x)`` for ``num`` values of ``x`` drawn uniformly from ``[low, high)``.

    The outputs therefore lie in ``(exp(-high), exp(-low)]``.

    Args:
        search_space: Object exposing ``low`` and ``high`` bounds. Both must be
            positive, with ``low <= high``.
        key: JAX PRNG key used for sampling.
        num: Number of samples to draw.

    Returns:
        Array of shape ``(num,)``.
    """
    lo, hi = search_space.low, search_space.high
    assert lo > 0 and hi > 0 and lo <= hi

    x = jax.random.uniform(key, (num,), minval=lo, maxval=hi)
    return jnp.exp(-x)