Source code for evorl.algorithms.meta.pbt_utils
1import pandas as pd
2
3import chex
4import jax
5import jax.numpy as jnp
6
7
[docs]
def convert_pop_to_df(pop):
    """Build a DataFrame from ``pop`` with a leading ``pop_id`` column.

    Args:
        pop: Data accepted by ``pandas.DataFrame.from_dict``.

    Returns:
        The resulting DataFrame, with a ``pop_id`` column inserted at
        position 0 that numbers the rows ``0 .. len - 1``.
    """
    frame = pd.DataFrame.from_dict(pop)
    # Row numbers are taken after construction so they match the frame length.
    row_numbers = range(len(frame))
    frame.insert(loc=0, column="pop_id", value=row_numbers)
    return frame
12
13
[docs]
def uniform_init(search_space, key: chex.PRNGKey, num: int) -> chex.Array:
    """Draw ``num`` samples uniformly between the search-space bounds.

    Args:
        search_space: Object exposing ``low`` and ``high`` bounds.
        key: JAX PRNG key used for sampling.
        num: Number of samples to draw.

    Returns:
        Array of shape ``(num,)`` sampled with ``minval=low`` and
        ``maxval=high``.
    """
    lower_bound = search_space.low
    upper_bound = search_space.high
    assert lower_bound <= upper_bound

    return jax.random.uniform(
        key, shape=(num,), minval=lower_bound, maxval=upper_bound
    )
23
24
[docs]
def truncated_normal_init(
    search_space, key: chex.PRNGKey, num: int, m=0.95
) -> chex.Array:
    """Random sample from a truncated normal distribution.

    The underlying normal distribution is centred on the midpoint of
    ``[search_space.low, search_space.high]``. Its standard deviation is
    chosen so that the bounds sit ``z`` standard deviations from the mean,
    where ``z`` is the two-sided z-score for probability mass ``m``.
    Samples are drawn from that distribution truncated to the bounds.

    Args:
        search_space: Object exposing ``low`` and ``high`` bounds.
        key: JAX PRNG key used for sampling.
        num: Number of samples to draw.
        m: Probability mass of the untruncated normal that should fall
            between ``low`` and ``high``.

    Returns:
        Array of shape ``(num,)`` with the sampled values.
    """
    assert search_space.low <= search_space.high

    # Two-sided z-score for mass m. For m=0.95 this is about 1.96, meaning
    # a value drawn from the untruncated normal falls within [low, high]
    # with 95% probability.
    z = jax.scipy.stats.norm.ppf(1 - (1 - m) / 2)

    mu = (search_space.high + search_space.low) / 2

    # Half the interval width corresponds to z standard deviations.
    std = (search_space.high - search_space.low) / (2 * z)

    # jax.random.truncated_normal samples a standard normal, so the bounds
    # are standardised here and the samples rescaled afterwards.
    samples = jax.random.truncated_normal(
        key,
        lower=(search_space.low - mu) / std,
        upper=(search_space.high - mu) / std,
        shape=(num,),
    )

    # Fix: the original ended with a bare `return`, yielding None.
    return samples * std + mu
49
50
[docs]
def log_uniform_init(search_space, key: chex.PRNGKey, num: int) -> chex.Array:
    """Draw samples uniformly in log space and map them back with ``exp``.

    Suitable for hyperparameters that need to explore different orders of
    magnitude in a positive range, e.g. [1e-3, 100].

    Args:
        search_space: Object exposing positive ``low`` and ``high`` bounds.
        key: JAX PRNG key used for sampling.
        num: Number of samples to draw.

    Returns:
        Array of shape ``(num,)``: the exponential of uniform samples drawn
        between ``log(low)`` and ``log(high)``.
    """
    # Both bounds must be strictly positive for the logarithm, and ordered.
    assert search_space.low > 0
    assert search_space.high > 0
    assert search_space.low <= search_space.high

    log_low = jnp.log(search_space.low)
    log_high = jnp.log(search_space.high)
    log_samples = jax.random.uniform(
        key, (num,), minval=log_low, maxval=log_high
    )
    return jnp.exp(log_samples)
70
71
[docs]
def exp_uniform_init(search_space, key: chex.PRNGKey, num: int) -> chex.Array:
    """Sample ``exp(-x)`` with ``x`` drawn uniformly between the bounds.

    Args:
        search_space: Object exposing positive ``low`` and ``high`` bounds
            for the exponent ``x``.
        key: JAX PRNG key used for sampling.
        num: Number of samples to draw.

    Returns:
        Array of shape ``(num,)`` holding ``exp(-x)`` for each sampled ``x``.
    """
    assert search_space.low > 0
    assert search_space.high > 0
    assert search_space.low <= search_space.high

    # The bounds apply to the exponent x, not to the returned values.
    exponents = jax.random.uniform(
        key, (num,), minval=search_space.low, maxval=search_space.high
    )
    return jnp.exp(-exponents)