1import copy
2import logging
3from omegaconf import DictConfig, OmegaConf
4from typing_extensions import Self
5
6import chex
7import jax
8import jax.numpy as jnp
9import jax.tree_util as jtu
10
11from evorl.distributed import POP_AXIS_NAME, all_gather
12from evorl.metrics import (
13 MetricBase,
14 ECWorkflowMetric,
15 MultiObjectiveECWorkflowMetric,
16 ECTrainMetric,
17)
18from evorl.ec.optimizers import EvoOptimizer, ECState
19from evorl.envs import Env
20from evorl.sample_batch import SampleBatch
21from evorl.evaluators import Evaluator, EpisodeCollector
22from evorl.agent import Agent, AgentState, AgentStateAxis
23from evorl.distributed import get_global_ranks, psum, split_key_to_devices
24from evorl.types import State, PyTreeData, pytree_field, Params, PyTreeDict
25from evorl.utils.rl_toolkits import flatten_pop_rollout_episode
26from evorl.utils.jax_utils import tree_stop_gradient
27
28from .workflow import Workflow
29
30
31logger = logging.getLogger(__name__)
32
33
[docs]
class DistributedInfo(PyTreeData):
    """Per-device bookkeeping carried in the workflow state.

    The defaults describe single-device training: rank zero in a world of
    size one.
    """

    # Rank of the current device; a scalar int32 zero by default. The
    # workflows' `step()` multiplies it by the slice size to find the start
    # of the population slice evaluated on this device.
    rank: int = jnp.zeros(shape=(), dtype=jnp.int32)
    # Number of slices the population is divided into. Declared as a static
    # pytree field, so it is available as a plain Python value.
    world_size: int = pytree_field(static=True, default=1)
39
40
[docs]
class ECWorkflow(Workflow):
    """Common base class of the EC (Evolutionary Computation) workflows.

    It owns the device bookkeeping (`dp_axis_name`, `devices`) and the
    `build_from_config()` factory. Concrete workflows provide
    `_build_from_config()` and may customise `_rescale_config()`,
    `enable_jit()` and `enable_shmap()`.
    """

    def __init__(self, config: DictConfig):
        """Create the workflow in its single-device configuration.

        Args:
            config: the config object
        """
        super().__init__(config)

        # Single-device defaults. `build_from_config()` overwrites both
        # attributes when multi-devices training is requested.
        self.dp_axis_name = None
        self.devices = jax.local_devices()[:1]

    @property
    def enable_multi_devices(self) -> bool:
        """True once a data-parallel axis name has been assigned."""
        axis_name = self.dp_axis_name
        return axis_name is not None

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ):
        """Create a workflow instance from a config.

        The config is deep-copied first, so the caller's object is never
        modified, and the copy is made read-only before the instance is built.

        Args:
            config: Config of the workflow
            enable_multi_devices: Whether multi-devices training is enabled.
                When True, `enable_jit` is not consulted.
            enable_jit: Whether jit is enabled

        Returns:
            The workflow created by `_build_from_config()`.
        """
        cfg = copy.deepcopy(config)
        local_devices = jax.local_devices()

        if not enable_multi_devices:
            # Single-device path: optionally jit, then build with a frozen config.
            if enable_jit:
                cls.enable_jit()
            OmegaConf.set_readonly(cfg, True)
            return cls._build_from_config(cfg)

        # Multi-devices path: install the shmap behavior, then let the
        # subclass rescale the (temporarily writable) config in place.
        cls.enable_shmap(POP_AXIS_NAME)
        OmegaConf.set_readonly(cfg, False)
        cls._rescale_config(cfg)
        OmegaConf.set_readonly(cfg, True)

        instance = cls._build_from_config(cfg)
        instance.dp_axis_name = POP_AXIS_NAME
        instance.devices = local_devices
        return instance

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Build the workflow instance; must be provided by subclasses.

        Args:
            config: Config of the workflow

        Returns:
            workflow: the created workflow instance

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Hook for adapting config values to multi-devices training.

        `build_from_config()` calls it only when enable_multi_devices=True,
        and implementations are expected to edit `config` in place. The base
        implementation leaves the config unchanged.

        Args:
            config: Config of the workflow
        """
        return None

    @classmethod
    def enable_jit(cls) -> None:
        """Define which methods should be jitted.

        The base implementation replaces the class's `step()` with a jitted
        version that treats argument 0 (`self`) as static.
        """
        jitted_step = jax.jit(cls.step, static_argnums=(0,))
        cls.step = jitted_step

    @classmethod
    def enable_shmap(cls, axis_name) -> None:
        """Define which methods should be shmaped.

        This is the hook that defines the multi-device behavior. The base
        implementation does nothing; subclasses override it as needed.

        Args:
            axis_name: Name of the data-parallel axis.
        """
        return None
132
133
[docs]
class ECWorkflowTemplate(ECWorkflow):
    """Workflow template for EC algorithms.

    Each call to `step()` runs one ask / evaluate / tell iteration: the
    optimizer proposes a population, this device evaluates its slice of it,
    the fitnesses are gathered over the data-parallel axis and handed back to
    the optimizer.

    Attributes:
        env: Environment object.
        agent: Workflow-specific agent object.
        ec_optimizer: EC Optimizer of the agent.
        ec_evaluator: Evaluator or episode collector used by `step()` to
            evaluate the population.
        agent_state_vmap_axes: Vmap axis for the agent state.
        config: Config of the workflow.
    """

    def __init__(
        self,
        *,
        env: Env,
        agent: Agent,
        ec_optimizer: EvoOptimizer,
        ec_evaluator: Evaluator | EpisodeCollector,
        agent_state_vmap_axes: AgentStateAxis = 0,
        config: DictConfig,
    ):
        """Initialize the ECWorkflowTemplate instance.

        Args:
            env: Environment object.
            agent: Workflow-specific agent object.
            ec_optimizer: EC Optimizer of the agent.
            ec_evaluator: Evaluator or episode collector used by `step()` to
                evaluate the population.
            agent_state_vmap_axes: Vmap axis for the agent state.
            config: Config of the workflow.
        """
        super().__init__(config)

        self.agent = agent
        self.env = env
        self.ec_optimizer = ec_optimizer
        self.ec_evaluator = ec_evaluator
        self.agent_state_vmap_axes = agent_state_vmap_axes

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, ECState]:
        """Setup Agent and ECOptimizer states.

        Must be implemented by subclasses.

        Args:
            key: JAX PRNGKey

        Returns:
            Tuple of (agent_state, ec_state)

        Raises:
            NotImplementedError: Always, in this template.
        """
        raise NotImplementedError

    def _setup_workflow_metrics(self) -> MetricBase:
        """Define Workflow metrics.

        `best_objective` starts at the smallest finite float32 so that the
        running maximum in `step()` is replaced by the first real value.
        """
        return ECWorkflowMetric(best_objective=jnp.finfo(jnp.float32).min)

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial workflow state.

        Args:
            key: JAX PRNGKey

        Returns:
            State holding the PRNG key, agent state, optimizer state, workflow
            metrics and distributed info, after `_postsetup()` was applied.
        """
        key, agent_key = jax.random.split(key, 2)

        # agent_state: store params not optimized by EC (eg: obs_preprocessor_state)
        agent_state, ec_opt_state = self._setup_agent_and_optimizer(agent_key)
        workflow_metrics = self._setup_workflow_metrics()
        distributed_info = DistributedInfo()

        if self.enable_multi_devices:
            # An empty PartitionSpec is used for these states, i.e. no array
            # dimension is partitioned over the mesh axis.
            sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(self.devices, (self.dp_axis_name,)),
                jax.sharding.PartitionSpec(),
            )
            agent_state, ec_opt_state, workflow_metrics = jax.device_put(
                (agent_state, ec_opt_state, workflow_metrics), sharding
            )
            key = split_key_to_devices(key, self.devices)

            # NOTE(review): the mesh is built from `self.devices` while
            # world_size uses jax.device_count(); confirm that both agree in
            # multi-host setups.
            distributed_info = DistributedInfo(
                rank=get_global_ranks(),
                world_size=jax.device_count(),
            )

        state = State(
            key=key,
            agent_state=agent_state,
            ec_opt_state=ec_opt_state,
            metrics=workflow_metrics,
            distributed_info=distributed_info,
        )

        state = self._postsetup(state)

        return state

    def _postsetup(self, state: State) -> State:
        """Post-setup state before training.

        By default, no post-setup is applied
        """
        return state

    def _replace_actor_params(
        self, agent_state: AgentState, params: Params
    ) -> AgentState:
        """Define how to replace the pop agent_state from the population params.

        Must be implemented by subclasses.

        Args:
            agent_state: State of the agent.
            params: Population params.

        Returns:
            New agent_state with replaced population params.

        Raises:
            NotImplementedError: Always, in this template.
        """
        raise NotImplementedError

    def _update_obs_preprocessor(
        self, agent_state: AgentState, trajectory: SampleBatch
    ) -> AgentState:
        """Update the obs_preprocessor_state based on sampled trajectories.

        By default, don't update obs_preprocessor_state.

        Args:
            agent_state: State of the agent
            trajectory: Episodic trajectory (T, B, ...)

        Returns:
            The (by default unchanged) agent state.
        """
        return agent_state

    def _metrics_to_fitnesses(self, metrics: MetricBase) -> chex.ArrayTree:
        """Convert the rollout metrics to fitnesses.

        By default, use the mean of episode_returns over multiple episodes as fitnesses.

        Args:
            metrics: Rollout metrics.

        Returns:
            Fitnesses of the population.
        """
        return jnp.mean(metrics.episode_returns, axis=-1)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one ask / evaluate / tell iteration.

        Args:
            state: Current workflow state.

        Returns:
            Tuple of (train_metrics, new_state).

        Raises:
            ValueError: If the population size is not divisible by
                `state.distributed_info.world_size`.
            TypeError: If `self.ec_evaluator` is neither an `EpisodeCollector`
                nor an `Evaluator`.
        """
        agent_state = state.agent_state
        key, rollout_key = jax.random.split(state.key, 2)

        pop, ec_opt_state = self.ec_optimizer.ask(state.ec_opt_state)
        pop_size = jtu.tree_leaves(pop)[0].shape[0]

        world_size = state.distributed_info.world_size
        if pop_size % world_size != 0:
            # Otherwise the trailing individuals would never be evaluated and
            # tell() would receive fewer fitnesses than individuals asked.
            raise ValueError(
                f"Population size ({pop_size}) must be divisible by "
                f"world_size ({world_size})."
            )

        # Each device evaluates a contiguous slice of the population,
        # starting at rank * slice_size.
        slice_size = pop_size // world_size
        eval_pop = jtu.tree_map(
            lambda x: jax.lax.dynamic_slice_in_dim(
                x, state.distributed_info.rank * slice_size, slice_size, axis=0
            ),
            pop,
        )

        pop_agent_state = self._replace_actor_params(agent_state, eval_pop)

        if isinstance(self.ec_evaluator, EpisodeCollector):
            # trajectory: [#pop, T, #episodes]
            rollout_metrics, trajectory = jax.vmap(
                self.ec_evaluator.rollout,
                in_axes=(self.agent_state_vmap_axes, 0, None),
            )(
                pop_agent_state,
                jax.random.split(rollout_key, num=slice_size),
                self.config.episodes_for_fitness,
            )
            # [#pop, T, B, ...] -> [T, #pop*B, ...]
            trajectory = flatten_pop_rollout_episode(trajectory)
            trajectory = tree_stop_gradient(trajectory)
            agent_state = self._update_obs_preprocessor(agent_state, trajectory)

        elif isinstance(self.ec_evaluator, Evaluator):
            rollout_metrics = jax.vmap(
                self.ec_evaluator.evaluate,
                in_axes=(self.agent_state_vmap_axes, 0, None),
            )(
                pop_agent_state,
                jax.random.split(rollout_key, num=slice_size),
                self.config.episodes_for_fitness,
            )

        else:
            # Without this branch rollout_metrics would be unbound below.
            raise TypeError(
                "ec_evaluator must be an EpisodeCollector or an Evaluator, "
                f"got {type(self.ec_evaluator).__name__}"
            )

        fitnesses = self._metrics_to_fitnesses(rollout_metrics)
        fitnesses = all_gather(fitnesses, self.dp_axis_name, axis=0, tiled=True)

        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        # Count only the episodes sampled on this device: the psum over the
        # data-parallel axis then yields the global total. Using the full
        # pop_size here would inflate the count by world_size.
        sampled_episodes = psum(
            jnp.uint32(slice_size * self.config.episodes_for_fitness),
            self.dp_axis_name,
        )
        sampled_timesteps_m = (
            psum(rollout_metrics.episode_lengths.sum(), self.dp_axis_name) / 1e6
        )

        # NOTE(review): best_objective only sees the episode returns of the
        # slice evaluated on this device; confirm whether a cross-device
        # maximum is intended for multi-devices training.
        workflow_metrics = state.metrics.replace(
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            sampled_timesteps_m=state.metrics.sampled_timesteps_m + sampled_timesteps_m,
            iterations=state.metrics.iterations + 1,
            best_objective=jnp.maximum(
                state.metrics.best_objective, jnp.max(rollout_metrics.episode_returns)
            ),
        )

        train_metrics = ECTrainMetric(
            objectives=fitnesses,
            ec_metrics=ec_metrics,
        )

        return train_metrics, state.replace(
            key=key,
            agent_state=agent_state,
            ec_opt_state=ec_opt_state,
            metrics=workflow_metrics,
        )

    @classmethod
    def enable_jit(cls) -> None:
        """Jit `_postsetup()` in addition to what the parent class jits."""
        super().enable_jit()
        cls._postsetup = jax.jit(cls._postsetup, static_argnums=(0,))

    @classmethod
    def enable_shmap(cls, axis_name) -> None:
        """Define the multi-device behavior; adds nothing to the parent hook.

        Args:
            axis_name: Name of the data-parallel axis.
        """
        super().enable_shmap(axis_name)
        # Note: In the new paradigm, we use shmap_vmap where needed instead of wrappers here.
        pass
359
360
[docs]
class MultiObjectiveECWorkflowTemplate(ECWorkflowTemplate):
    """Workflow template for multi-objective EC algorithms.

    Differs from the single-objective template in how fitnesses are built
    (one column per entry of `config.metric_names`) and in the workflow
    metrics, which do not track a `best_objective`.
    """

    def _metrics_to_fitnesses(self, metrics: MetricBase) -> chex.ArrayTree:
        """Convert the rollout metrics to multi-objective fitnesses.

        Each metric named in `self.config.metric_names` is averaged over its
        last axis and the results are stacked along a new trailing axis. When
        only one metric is configured, that trailing axis is squeezed away.

        Args:
            metrics: Rollout metrics, indexable by metric name.

        Returns:
            Fitnesses of the population.
        """
        fitnesses = jnp.stack(
            [jnp.mean(metrics[k], axis=-1) for k in self.config.metric_names], axis=-1
        )
        if fitnesses.shape[-1] == 1:
            fitnesses = fitnesses.squeeze(-1)

        return fitnesses

    def _setup_workflow_metrics(self) -> MetricBase:
        """Define Workflow metrics for the multi-objective case."""
        return MultiObjectiveECWorkflowMetric()

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one ask / evaluate / tell iteration.

        Args:
            state: Current workflow state.

        Returns:
            Tuple of (train_metrics, new_state).

        Raises:
            ValueError: If the population size is not divisible by
                `state.distributed_info.world_size`.
            TypeError: If `self.ec_evaluator` is neither an `EpisodeCollector`
                nor an `Evaluator`.
        """
        agent_state = state.agent_state
        key, rollout_key = jax.random.split(state.key, 2)

        pop, ec_opt_state = self.ec_optimizer.ask(state.ec_opt_state)
        pop_size = jtu.tree_leaves(pop)[0].shape[0]

        world_size = state.distributed_info.world_size
        if pop_size % world_size != 0:
            # Otherwise the trailing individuals would never be evaluated and
            # tell() would receive fewer fitnesses than individuals asked.
            raise ValueError(
                f"Population size ({pop_size}) must be divisible by "
                f"world_size ({world_size})."
            )

        # Each device evaluates a contiguous slice of the population,
        # starting at rank * slice_size.
        slice_size = pop_size // world_size
        eval_pop = jtu.tree_map(
            lambda x: jax.lax.dynamic_slice_in_dim(
                x, state.distributed_info.rank * slice_size, slice_size, axis=0
            ),
            pop,
        )

        pop_agent_state = self._replace_actor_params(agent_state, eval_pop)

        if isinstance(self.ec_evaluator, EpisodeCollector):
            # trajectory: [#pop, T, #episodes]
            rollout_metrics, trajectory = jax.vmap(
                self.ec_evaluator.rollout,
                in_axes=(self.agent_state_vmap_axes, 0, None),
            )(
                pop_agent_state,
                jax.random.split(rollout_key, num=slice_size),
                self.config.episodes_for_fitness,
            )
            # [#pop, T, B, ...] -> [T, #pop*B, ...]
            trajectory = flatten_pop_rollout_episode(trajectory)
            trajectory = tree_stop_gradient(trajectory)
            agent_state = self._update_obs_preprocessor(agent_state, trajectory)

        elif isinstance(self.ec_evaluator, Evaluator):
            rollout_metrics = jax.vmap(
                self.ec_evaluator.evaluate,
                in_axes=(self.agent_state_vmap_axes, 0, None),
            )(
                pop_agent_state,
                jax.random.split(rollout_key, num=slice_size),
                self.config.episodes_for_fitness,
            )

        else:
            # Without this branch rollout_metrics would be unbound below.
            raise TypeError(
                "ec_evaluator must be an EpisodeCollector or an Evaluator, "
                f"got {type(self.ec_evaluator).__name__}"
            )

        fitnesses = self._metrics_to_fitnesses(rollout_metrics)
        fitnesses = all_gather(fitnesses, self.dp_axis_name, axis=0, tiled=True)

        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        # Count only the episodes sampled on this device: the psum over the
        # data-parallel axis then yields the global total. Using the full
        # pop_size here would inflate the count by world_size.
        sampled_episodes = psum(
            jnp.uint32(slice_size * self.config.episodes_for_fitness),
            self.dp_axis_name,
        )
        sampled_timesteps_m = (
            psum(rollout_metrics.episode_lengths.sum(), self.dp_axis_name) / 1e6
        )

        workflow_metrics = state.metrics.replace(
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            sampled_timesteps_m=state.metrics.sampled_timesteps_m + sampled_timesteps_m,
            iterations=state.metrics.iterations + 1,
        )

        train_metrics = ECTrainMetric(objectives=fitnesses, ec_metrics=ec_metrics)

        return train_metrics, state.replace(
            key=key,
            agent_state=agent_state,
            ec_opt_state=ec_opt_state,
            metrics=workflow_metrics,
        )