1import logging
2from omegaconf import DictConfig
3from typing_extensions import Self # pytype: disable=not-supported-yet]
4from collections.abc import Callable
5
6import chex
7import jax
8import jax.numpy as jnp
9import jax.tree_util as jtu
10import optax
11
12from evorl.types import Params, pytree_field, PyTreeDict
13from evorl.envs import AutoresetMode, create_env
14from evorl.evaluators import Evaluator
15from evorl.agent import AgentState
16from evorl.utils.jax_utils import rng_split_like_tree
17from evorl.ec.optimizers import EvoOptimizer, ECState
18from evorl.ec.optimizers.openes import compute_centered_ranks, OpenESState
19from evorl.ec.optimizers.utils import weight_sum
20
21from evorl.algorithms.ec.so.openes import OpenESWorkflow
22from evorl.algorithms.ec.ec_agent import make_deterministic_ec_agent
23
24
25logger = logging.getLogger(__name__)
26
27
[docs]
class OpenES(EvoOptimizer):
    """OpenES-style evolutionary optimizer over a parameter pytree.

    ``ask`` perturbs the current mean with Gaussian noise scaled by the
    state's ``noise_std``. ``tell`` converts the shaped fitnesses into a
    gradient estimate and applies it to the mean through an Adam optimizer
    wrapped with ``optax.inject_hyperparams``.
    """

    # Number of candidates produced by each ``ask`` call; must be positive
    # (and even when ``mirror_sampling`` is enabled).
    pop_size: int
    # Passed to Adam as ``learning_rate``.
    lr: float
    # Initial noise scale; ``init`` stores it in the state as a float32 scalar.
    noise_std: float
    # When True, only half of the noise is drawn and the rest is its negation.
    mirror_sampling: bool = True
    # Optional L2 coefficient; ``None`` disables the weight-decay term.
    weight_decay: float | None = None

    # Maps raw fitnesses to the weights used in the gradient estimate.
    fitness_shaping_fn: Callable[[chex.Array], chex.Array] = pytree_field(
        static=True, default=compute_centered_ranks
    )
    # Built in ``__post_init__``; not a constructor argument (``init=False``).
    optimizer: optax.GradientTransformation = pytree_field(static=True, init=False)

    def __post_init__(self):
        """Validate the population size and build the Adam optimizer."""
        assert self.pop_size > 0, "pop_size must be positive"
        assert (
            not self.mirror_sampling or self.pop_size % 2 == 0
        ), "pop_size must be even for mirror sampling"

        # Everything except the learning rate is declared static for
        # ``inject_hyperparams``.
        adam_factory = optax.inject_hyperparams(
            optax.adam, static_args=("b1", "b2", "eps", "eps_root")
        )
        self.optimizer = adam_factory(learning_rate=self.lr)

    def init(self, mean: Params, key: chex.PRNGKey) -> ECState:
        """Create the initial optimizer state.

        Args:
            mean: Initial mean parameters (a pytree).
            key: PRNG key stored in the state and consumed by ``ask``.

        Returns:
            An ``OpenESState`` holding the mean, a fresh Adam state for it,
            the noise scale as a float32 scalar, and the key.
        """
        initial_opt_state = self.optimizer.init(mean)
        initial_sigma = jnp.float32(self.noise_std)
        return OpenESState(
            mean=mean,
            opt_state=initial_opt_state,
            noise_std=initial_sigma,
            key=key,
        )

    def ask(self, state: ECState) -> tuple[chex.ArrayTree, ECState]:
        """Generate new candidate solutions.

        Args:
            state: Current optimizer state; ``key``, ``mean`` and
                ``noise_std`` are read from it.

        Returns:
            A ``(pop, state)`` pair. ``pop`` is ``mean + noise_std * noise``
            leaf by leaf, where each noise leaf has a leading axis of length
            ``pop_size``. The returned state carries the advanced key and
            the sampled noise, which ``tell`` reads back.
        """
        next_key, draw_key = jax.random.split(state.key)
        # Presumably yields one PRNG key per leaf of the mean pytree.
        leaf_keys = rng_split_like_tree(draw_key, state.mean)

        mirrored = self.mirror_sampling
        # With mirroring, only half of the population is actually sampled.
        num_draws = self.pop_size // 2 if mirrored else self.pop_size

        def _draw(leaf, leaf_key):
            return jax.random.normal(leaf_key, shape=(num_draws, *leaf.shape))

        noise = jtu.tree_map(_draw, state.mean, leaf_keys)
        if mirrored:
            # Second half of the population is the negated first half.
            noise = jtu.tree_map(
                lambda eps: jnp.concatenate([eps, -eps], axis=0), noise
            )

        sigma = state.noise_std

        def _perturb(center, eps):
            return center + sigma * eps

        population = jtu.tree_map(_perturb, state.mean, noise)
        return population, state.replace(key=next_key, noise=noise)

    def tell(
        self, state: ECState, fitnesses: chex.Array
    ) -> tuple[PyTreeDict, OpenESState]:
        """Update the mean from the fitnesses of the last sampled population.

        Args:
            state: State returned by ``ask``; its ``noise`` field is used.
            fitnesses: Fitness values of the candidates produced by ``ask``.

        Returns:
            An empty ``PyTreeDict`` and the new state with updated ``mean``
            and ``opt_state``; ``noise`` is reset to ``None``.
        """
        shaped = self.fitness_shaping_fn(fitnesses)

        # grad = 1/(N*sigma^2) * sum(F_i*(x_i-m)); with x_i - m = sigma * z_i
        # this is sum(F_i * z_i) / (N * sigma).
        denom = self.pop_size * state.noise_std

        def _estimate(eps):
            # Negated because the optimizer minimizes while fitness is
            # maximized. ``weight_sum`` presumably returns the
            # fitness-weighted sum of ``eps`` over the population axis.
            return -weight_sum(eps, shaped) / denom

        grad = jtu.tree_map(_estimate, state.noise)

        # Add L2 weight decay.
        if self.weight_decay is not None:
            decay = self.weight_decay

            def _with_decay(g, center):
                return g + decay * center

            grad = jtu.tree_map(_with_decay, grad, state.mean)

        step, new_opt_state = self.optimizer.update(grad, state.opt_state)
        new_mean = optax.apply_updates(state.mean, step)

        new_state = state.replace(mean=new_mean, opt_state=new_opt_state, noise=None)
        return PyTreeDict(), new_state
112
113
[docs]
class ParamOpenESWorkflow(OpenESWorkflow):
    """OpenES workflow assembled from a config, using a deterministic EC agent."""

    @classmethod
    def name(cls):
        """Return the identifier string of this workflow."""
        return "ParamOpenES"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from ``config``.

        Builds, in order: the training environment, the agent, the OpenES
        optimizer, the population evaluator, a separate evaluation
        environment, and the evaluator for the population mean.

        Args:
            config: Workflow configuration. The fields read here are
                ``env`` (including ``env.max_episode_steps``), ``num_envs``,
                ``num_eval_envs``, ``agent_network``, ``normalize_obs``,
                ``pop_size``, ``ec_lr``, ``ec_noise_std``,
                ``mirror_sampling``, ``weight_decay``, ``explore`` and
                ``discount``.

        Returns:
            A new instance of ``cls``.
        """
        max_steps = config.env.max_episode_steps
        net_cfg = config.agent_network

        # Both environments share the episode length and disable autoreset.
        shared_env_kwargs = dict(
            episode_length=max_steps,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        env = create_env(config.env, parallel=config.num_envs, **shared_env_kwargs)

        agent = make_deterministic_ec_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=net_cfg.actor_hidden_layer_sizes,
            use_bias=net_cfg.use_bias,
            normalize_obs=config.normalize_obs,
            norm_layer_type=net_cfg.norm_layer_type,
        )

        ec_optimizer = OpenES(
            pop_size=config.pop_size,
            lr=config.ec_lr,
            noise_std=config.ec_noise_std,
            mirror_sampling=config.mirror_sampling,
            weight_decay=config.weight_decay,
        )

        # ``explore`` selects which agent action function drives the
        # population rollouts.
        rollout_action_fn = (
            agent.compute_actions if config.explore else agent.evaluate_actions
        )

        ec_evaluator = Evaluator(
            env=env,
            action_fn=rollout_action_fn,
            max_episode_steps=max_steps,
            discount=config.discount,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env, parallel=config.num_eval_envs, **shared_env_kwargs
        )
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_steps,
        )

        # Presumably a vmap axis spec: params on axis 0, no axis for the
        # obs preprocessor state (TODO confirm against OpenESWorkflow).
        vmap_axes = AgentState(params=0, obs_preprocessor_state=None)

        return cls(
            config=config,
            env=env,
            agent=agent,
            ec_optimizer=ec_optimizer,
            ec_evaluator=ec_evaluator,
            evaluator=evaluator,
            agent_state_vmap_axes=vmap_axes,
        )