Source code for evorl.algorithms.meta.pbt_openes.param_openes

  1import logging
  2from omegaconf import DictConfig
  3from typing_extensions import Self  # pytype: disable=not-supported-yet]
  4from collections.abc import Callable
  5
  6import chex
  7import jax
  8import jax.numpy as jnp
  9import jax.tree_util as jtu
 10import optax
 11
 12from evorl.types import Params, pytree_field, PyTreeDict
 13from evorl.envs import AutoresetMode, create_env
 14from evorl.evaluators import Evaluator
 15from evorl.agent import AgentState
 16from evorl.utils.jax_utils import rng_split_like_tree
 17from evorl.ec.optimizers import EvoOptimizer, ECState
 18from evorl.ec.optimizers.openes import compute_centered_ranks, OpenESState
 19from evorl.ec.optimizers.utils import weight_sum
 20
 21from evorl.algorithms.ec.so.openes import OpenESWorkflow
 22from evorl.algorithms.ec.ec_agent import make_deterministic_ec_agent
 23
 24
 25logger = logging.getLogger(__name__)
 26
 27
class OpenES(EvoOptimizer):
    """OpenES-style evolution strategy driven by an optax Adam optimizer.

    Candidates are produced as ``mean + noise_std * z`` with ``z`` drawn from a
    standard normal distribution. The search gradient estimated from the shaped
    fitnesses is handed to Adam to move ``mean``.
    """

    # Number of candidates produced by each call to `ask`.
    pop_size: int
    # Learning rate passed to Adam as `learning_rate`.
    lr: float
    # Scale applied to the sampled noise; stored in the state as float32.
    noise_std: float
    # When True, the noise is drawn for half the population and mirrored (z, -z).
    mirror_sampling: bool = True
    # Optional L2 coefficient; when set, `weight_decay * mean` is added to the gradient.
    weight_decay: float | None = None

    # Maps raw fitnesses to the weights used in the gradient estimate.
    fitness_shaping_fn: Callable[[chex.Array], chex.Array] = pytree_field(
        static=True, default=compute_centered_ranks
    )
    # Built in `__post_init__`; not an `__init__` argument.
    optimizer: optax.GradientTransformation = pytree_field(static=True, init=False)

    def __post_init__(self):
        """Check the population size and build the Adam optimizer."""
        assert self.pop_size > 0, "pop_size must be positive"
        if self.mirror_sampling:
            assert self.pop_size % 2 == 0, "pop_size must be even for mirror sampling"

        # Wrap Adam with inject_hyperparams; every argument except the
        # learning rate is declared static.
        adam_factory = optax.inject_hyperparams(
            optax.adam, static_args=("b1", "b2", "eps", "eps_root")
        )
        self.optimizer = adam_factory(learning_rate=self.lr)

    def init(self, mean: Params, key: chex.PRNGKey) -> ECState:
        """Create the initial optimizer state centered at ``mean``.

        Args:
            mean: Initial parameter pytree.
            key: PRNG key stored in the state and consumed by `ask`.

        Returns:
            An ``OpenESState`` holding the mean, a fresh optax state, the
            float32 noise scale and the PRNG key.
        """
        initial_opt_state = self.optimizer.init(mean)
        return OpenESState(
            mean=mean,
            opt_state=initial_opt_state,
            noise_std=jnp.float32(self.noise_std),
            key=key,
        )

    def ask(self, state: ECState) -> tuple[chex.ArrayTree, ECState]:
        """Generate new candidate solutions.

        Args:
            state: Current optimizer state.

        Returns:
            A tuple ``(pop, state)``: ``pop`` mirrors the structure of
            ``state.mean`` with a leading population axis on every leaf, and
            ``state`` carries the advanced PRNG key plus the sampled noise
            that `tell` reads back.
        """
        next_key, noise_key = jax.random.split(state.key)
        leaf_keys = rng_split_like_tree(noise_key, state.mean)

        # With mirror sampling only half of the population is drawn directly.
        num_draws = self.pop_size // 2 if self.mirror_sampling else self.pop_size

        def _draw(leaf, leaf_key):
            return jax.random.normal(leaf_key, shape=(num_draws, *leaf.shape))

        noise = jtu.tree_map(_draw, state.mean, leaf_keys)

        if self.mirror_sampling:
            # The second half of the population is the negation of the first.
            def _mirror(z):
                return jnp.concatenate([z, -z], axis=0)

            noise = jtu.tree_map(_mirror, noise)

        def _perturb(center, z):
            return center + state.noise_std * z

        pop = jtu.tree_map(_perturb, state.mean, noise)

        return pop, state.replace(key=next_key, noise=noise)

    def tell(
        self, state: ECState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, OpenESState]:
        """Update the mean from the fitnesses of the last sampled population.

        Args:
            state: State returned by the preceding `ask` call; its ``noise``
                field is read here.
            fitnesses: Fitness values of the candidates from `ask`.

        Returns:
            A tuple of an empty ``PyTreeDict`` and the state with the updated
            mean and optax state; ``noise`` is reset to ``None``.
        """
        shaped_fitnesses = self.fitness_shaping_fn(fitnesses)

        # grad = 1/(N*sigma^2) * sum(F_i*(x_i-m)); since x_i - m = sigma * z_i,
        # one factor of sigma cancels and the denominator is N * sigma.
        denominator = self.pop_size * state.noise_std

        def _estimate(z):
            # Note: we need additional "-1.0" since we are maximizing the fitness
            return -weight_sum(z, shaped_fitnesses) / denominator

        grad = jtu.tree_map(_estimate, state.noise)

        # add L2 weight decay
        if self.weight_decay is not None:

            def _decay(g, x):
                return g + self.weight_decay * x

            grad = jtu.tree_map(_decay, grad, state.mean)

        update, opt_state = self.optimizer.update(grad, state.opt_state)
        new_mean = optax.apply_updates(state.mean, update)

        new_state = state.replace(mean=new_mean, opt_state=opt_state, noise=None)
        return PyTreeDict(), new_state
112 113
class ParamOpenESWorkflow(OpenESWorkflow):
    """OpenES workflow that uses the `OpenES` optimizer defined in this module."""

    @classmethod
    def name(cls):
        """Return the workflow's name."""
        return "ParamOpenES"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Assemble the environments, agent, optimizer and evaluators from ``config``.

        Args:
            config: Workflow configuration. The keys read here are ``env``,
                ``num_envs``, ``num_eval_envs``, ``agent_network``,
                ``normalize_obs``, ``pop_size``, ``ec_lr``, ``ec_noise_std``,
                ``mirror_sampling``, ``weight_decay``, ``explore`` and
                ``discount``.

        Returns:
            A workflow instance wired with the constructed components.
        """
        episode_steps = config.env.max_episode_steps

        # Environment used to evaluate the population.
        env = create_env(
            config.env,
            episode_length=episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        network_cfg = config.agent_network
        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=network_cfg.actor_hidden_layer_sizes,
            use_bias=network_cfg.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=network_cfg.norm_layer_type,
        )

        ec_optimizer = OpenES(
            pop_size=config.pop_size,
            lr=config.ec_lr,
            noise_std=config.ec_noise_std,
            mirror_sampling=config.mirror_sampling,
            weight_decay=config.weight_decay,
        )

        # `explore` selects which agent action function the population
        # evaluator runs.
        action_fn = (
            agent.compute_actions if config.explore else agent.evaluate_actions
        )

        ec_evaluator = Evaluator(
            env=env,
            action_fn=action_fn,
            max_episode_steps=episode_steps,
            discount=config.discount,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=episode_steps,
        )

        # NOTE(review): presumably params are batched along axis 0 while the
        # obs preprocessor state is shared across the population -- confirm
        # against how OpenESWorkflow consumes agent_state_vmap_axes.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=agent_state_vmap_axes,
        )