1import copy
2import logging
3import math
4from functools import partial
5from omegaconf import DictConfig, OmegaConf, open_dict, read_write
6
7import chex
8import hydra
9import jax
10import jax.numpy as jnp
11import jax.tree_util as jtu
12from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
13from jax import shard_map
14
15
16from evorl.agent import RandomAgent
17from evorl.distributed import (
18 POP_AXIS_NAME,
19 shmap_vmap,
20 tree_device_put,
21)
22from evorl.rollout import rollout
23from evorl.metrics import MetricBase, EvaluateMetric
24from evorl.envs import AutoresetMode, create_env
25from evorl.evaluators import Evaluator
26from evorl.types import MISSING_REWARD, PyTreeDict, State, PyTreeData
27from evorl.replay_buffers import ReplayBufferState
28from evorl.recorders import get_1d_array_statistics, get_1d_array, add_prefix
29from evorl.utils.rl_toolkits import flatten_rollout_trajectory
30from evorl.utils.jax_utils import (
31 tree_get,
32 tree_set,
33 tree_stop_gradient,
34 scan_and_last,
35 is_jitted,
36)
37from evorl.utils import running_statistics
38from evorl.workflows import RLWorkflow, OffPolicyWorkflow, Workflow
39
40from .pbt_utils import convert_pop_to_df
41from .pbt_operations import explore, select
42from ..offpolicy_utils import clean_trajectory, skip_replay_buffer_state
43
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
45
46
[docs]
class PBTTrainMetric(MetricBase):
    """Metrics produced by one PBT training iteration (see ``PBTWorkflowBase.step``)."""

    # Evaluation episode returns, one entry per population member.
    pop_episode_returns: chex.Array
    # Evaluation episode lengths, one entry per population member.
    pop_episode_lengths: chex.Array
    # Per-member training metrics, as returned by `scan_and_last` over the
    # workflow steps executed in this iteration.
    pop_train_metrics: MetricBase
    # Hyperparameters of the population that produced the metrics above,
    # i.e. the population *before* exploit & explore was applied.
    pop: chex.ArrayTree
52
53
[docs]
class PBTOffpolicyTrainMetric(PBTTrainMetric):
    """``PBTTrainMetric`` extended with the size of the shared replay buffer."""

    # `buffer_size` of the shared replay buffer state after this iteration.
    rb_size: chex.Array
56
57
[docs]
class PBTEvalMetric(MetricBase):
    """Evaluation metrics of the whole population (see ``PBTWorkflowBase.evaluate``)."""

    # Mean evaluation episode return, one entry per population member.
    pop_episode_returns: chex.Array
    # Mean evaluation episode length, one entry per population member.
    pop_episode_lengths: chex.Array
61
62
[docs]
class PBTWorkflowMetric(MetricBase):
    """Workflow-level metrics of a PBT run (kept replicated on every device)."""

    # Total number of sampled timesteps, in millions, SUMMED over all
    # population members (not averaged): `step()` computes it with `jnp.sum`
    # over the members' sampled-timestep counters.
    sampled_timesteps_m: chex.Array = jnp.zeros((), dtype=jnp.float32)
    # Number of completed PBT iterations; each iteration runs
    # `config.workflow_steps_per_iter` steps of every member's workflow.
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
67
68
[docs]
class PBTOptState(PyTreeData):
    """State of the PBT hyperparameter optimizer.

    Carries no fields itself; it is the type returned by
    ``_setup_pop_and_pbt_optimizer`` and threaded through
    ``exploit_and_explore``, so PBT variants can extend it if needed.
    """
71
72
[docs]
class PBTWorkflowBase(Workflow):
    """Base class for Population Based Training (PBT) workflows.

    ``config.pop_size`` copies of the target ``workflow`` are trained side by
    side. The population (hyperparameters) and the per-member workflow states
    are sharded along ``POP_AXIS_NAME`` of a device mesh, while the PRNG key,
    the workflow-level metrics and the PBT optimizer state are replicated.

    Subclasses must implement ``_setup_pop_and_pbt_optimizer``,
    ``exploit_and_explore`` and ``apply_hyperparams_to_workflow_state``.
    """

    def __init__(self, workflow: RLWorkflow, evaluator: Evaluator, config: DictConfig):
        super().__init__(config)

        self.workflow = workflow
        self.evaluator = evaluator
        # Placeholders; build_from_config() overwrites both with the real mesh.
        self.devices = jax.local_devices()[:1]
        self.sharding = None  # training sharding

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Round ``config.pop_size`` down to a multiple of the device count.

        The population is split evenly across devices, so its size must be
        divisible by the number of devices. ``config`` is modified in place.

        Raises:
            ValueError: if ``pop_size`` is smaller than the number of devices,
                which would otherwise silently rescale the population to 0.
        """
        num_devices = jax.device_count()

        if config.pop_size < num_devices:
            raise ValueError(
                f"pop_size({config.pop_size}) must be >= num_devices({num_devices})"
            )

        if config.pop_size % num_devices != 0:
            logger.warning(
                f"pop_size({config.pop_size}) cannot be divided by num_devices({num_devices}), "
                f"rescale pop_size to {config.pop_size // num_devices * num_devices}"
            )

        config.pop_size = (config.pop_size // num_devices) * num_devices

    @classmethod
    def build_from_config(
        cls, config: DictConfig, enable_multi_devices=True, enable_jit: bool = True
    ):
        """Build a PBT workflow from ``config``.

        Args:
            config: Workflow config. It is deep-copied first, so the caller's
                object is never modified.
            enable_multi_devices: If True, shard the population over all local
                devices; otherwise place it on the first local device only.
            enable_jit: If True, jit ``setup``/``evaluate``/``step`` on the class.

        Returns:
            The constructed workflow with ``devices`` and ``sharding`` set.
        """
        config = copy.deepcopy(config)  # avoid in-place modification

        devices = jax.local_devices()
        if not enable_multi_devices:
            devices = devices[:1]

        OmegaConf.set_readonly(config, False)
        cls._rescale_config(config)

        if enable_jit:
            cls.enable_jit()

        OmegaConf.set_readonly(config, True)

        workflow = cls._build_from_config(config)

        mesh = Mesh(devices, axis_names=(POP_AXIS_NAME,))
        workflow.devices = devices
        workflow.sharding = NamedSharding(mesh, P(POP_AXIS_NAME))

        return workflow

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Instantiate the target workflow, the evaluator and the PBT workflow.

        The target workflow is built from ``config.target_workflow`` with the
        PBT ``env`` config copied in and its own checkpointing disabled. It is
        jitted only if this class's ``step`` is already jitted.
        """
        target_workflow_config = config.target_workflow
        target_workflow_config = copy.deepcopy(target_workflow_config)
        target_workflow_cls = hydra.utils.get_class(target_workflow_config.workflow_cls)

        devices = jax.local_devices()

        with read_write(target_workflow_config):
            with open_dict(target_workflow_config):
                target_workflow_config.env = copy.deepcopy(config.env)
                # disable target workflow ckpt
                target_workflow_config.checkpoint = OmegaConf.create(dict(enable=False))

        OmegaConf.set_readonly(target_workflow_config, True)

        enable_jit = is_jitted(cls.step)
        target_workflow = target_workflow_cls.build_from_config(
            target_workflow_config, enable_jit=enable_jit
        )

        target_workflow.devices = devices

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )
        evaluator = Evaluator(
            env=eval_env,
            action_fn=target_workflow.agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(target_workflow, evaluator, config)

    def _setup_pop_and_pbt_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[chex.ArrayTree, PBTOptState]:
        """Create the initial population and the PBT optimizer state.

        ``setup()`` shards the returned population across devices and
        replicates the optimizer state. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _customize_optimizer(self) -> None:
        """Hook called at the start of ``setup()``; a no-op by default."""
        pass

    def setup(self, key: chex.PRNGKey):
        """Initialise the population, every member's workflow state and metrics.

        Returns:
            A ``State`` whose ``pop`` and ``pop_workflow_state`` are sharded
            across devices, and whose ``key``, ``metrics`` and
            ``pbt_opt_state`` are replicated.
        """
        pop_size = self.config.pop_size
        self._customize_optimizer()

        key, workflow_key, pop_key = jax.random.split(key, num=3)

        pop, pbt_opt_state = self._setup_pop_and_pbt_optimizer(pop_key)
        pop = tree_device_put(pop, self.sharding)

        workflow_metrics = PBTWorkflowMetric()
        shared_sharding = NamedSharding(self.sharding.mesh, P())
        key, workflow_metrics, pbt_opt_state = jax.device_put(
            (key, workflow_metrics, pbt_opt_state), shared_sharding
        )

        workflow_keys = jax.random.split(workflow_key, pop_size)
        workflow_keys = jax.device_put(workflow_keys, self.sharding)
        pop_workflow_state = shmap_vmap(
            self.workflow.setup,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )(workflow_keys)

        # Note: for obs_preprocessor, we assume pop_workflow_state.agent_state.obs_preprocessor_state
        # is already same by initialization in self.workflow.setup(),
        # so we don't need sync them here.
        # Caution: for off-policy workflow with postsetup, this may not be true.

        pop_workflow_state = shmap_vmap(
            self.apply_hyperparams_to_workflow_state,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )(pop_workflow_state, pop)

        return State(
            key=key,  # shared
            metrics=workflow_metrics,  # shared
            pop_workflow_state=pop_workflow_state,  # across devices
            pop=pop,  # across devices
            pbt_opt_state=pbt_opt_state,  # shared
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one PBT iteration.

        Trains every member for ``config.workflow_steps_per_iter`` workflow
        steps, evaluates all members, and then applies exploit & explore unless
        still within the warm-up period (the first
        ``ceil(warmup_steps / workflow_steps_per_iter)`` iterations).

        Returns:
            ``(PBTTrainMetric, new_state)``.
        """
        pop_workflow_state = state.pop_workflow_state
        pop = state.pop
        pbt_opt_state = state.pbt_opt_state

        # ===== step ======
        def _train_steps(pop_wf_state):
            def _one_step(pop_wf_state, _):
                train_metrics, pop_wf_state = jax.vmap(self.workflow.step)(pop_wf_state)
                return pop_wf_state, train_metrics

            pop_wf_state, train_metrics = scan_and_last(
                _one_step, pop_wf_state, (), length=self.config.workflow_steps_per_iter
            )

            return train_metrics, pop_wf_state

        train_steps_fn = shard_map(
            _train_steps,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        pop_train_metrics, pop_workflow_state = train_steps_fn(pop_workflow_state)

        # ===== eval ======
        eval_fn = shmap_vmap(
            self.workflow.evaluate,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        pop_eval_metrics, pop_workflow_state = eval_fn(pop_workflow_state)

        # customize your pop metrics here
        pop_episode_returns = pop_eval_metrics.episode_returns

        # ===== warmup or exploit & explore ======
        key, exploit_and_explore_key = jax.random.split(state.key)

        def _dummy_fn(pbt_opt_state, pop, pop_workflow_state, pop_metrics, key):
            # Warm-up: leave population, workflow states and optimizer untouched.
            return pop, pop_workflow_state, pbt_opt_state

        pop, pop_workflow_state, pbt_opt_state = jax.lax.cond(
            state.metrics.iterations + 1
            <= math.ceil(
                self.config.warmup_steps / self.config.workflow_steps_per_iter
            ),
            _dummy_fn,
            self.exploit_and_explore,
            pbt_opt_state,
            pop,
            pop_workflow_state,
            pop_episode_returns,
            exploit_and_explore_key,
        )

        # ===== record metrics ======
        if hasattr(pop_workflow_state.metrics, "sampled_timesteps"):
            sampled_timesteps_m = jnp.sum(
                pop_workflow_state.metrics.sampled_timesteps / 1e6
            )
        elif hasattr(pop_workflow_state.metrics, "sampled_timesteps_m"):
            sampled_timesteps_m = jnp.sum(
                pop_workflow_state.metrics.sampled_timesteps_m
            )
        else:
            sampled_timesteps_m = jnp.zeros((), dtype=jnp.float32)

        # Note: sampled_timesteps_m is already accumulated in target_workflow
        workflow_metrics = state.metrics.replace(
            sampled_timesteps_m=sampled_timesteps_m,
            iterations=state.metrics.iterations + 1,
        )

        train_metrics = PBTTrainMetric(
            pop_episode_returns=pop_eval_metrics.episode_returns,
            pop_episode_lengths=pop_eval_metrics.episode_lengths,
            pop_train_metrics=pop_train_metrics,
            pop=state.pop,  # save prev pop instead of new pop to match the metrics
        )

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            pop=pop,
            pop_workflow_state=pop_workflow_state,
            pbt_opt_state=pbt_opt_state,
        )

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate every member with ``self.evaluator``.

        Each member runs ``config.eval_episodes`` episodes; the returns and
        lengths are averaged per member.

        Returns:
            ``(PBTEvalMetric, state)`` where ``state`` only has a new ``key``.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        def _evaluate(wf_state, key):
            # [#episodes]
            raw_eval_metrics = self.evaluator.evaluate(
                wf_state.agent_state, key, num_episodes=self.config.eval_episodes
            )

            eval_metrics = EvaluateMetric(
                episode_returns=raw_eval_metrics.episode_returns.mean(),
                episode_lengths=raw_eval_metrics.episode_lengths.mean(),
            )
            return eval_metrics

        eval_fn = shmap_vmap(
            _evaluate,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        pop_eval_metrics = eval_fn(
            state.pop_workflow_state, jax.random.split(eval_key, self.config.pop_size)
        )

        eval_metrics = PBTEvalMetric(
            pop_episode_returns=pop_eval_metrics.episode_returns,
            pop_episode_lengths=pop_eval_metrics.episode_lengths,
        )

        return eval_metrics, state.replace(key=key)

    def exploit_and_explore(
        self,
        pbt_opt_state: PBTOptState,
        pop: chex.ArrayTree,
        pop_workflow_state: State,
        pop_metrics: chex.ArrayTree,
        key: chex.PRNGKey,
    ) -> tuple[chex.ArrayTree, State, PBTOptState]:
        """Update the population from ``pop_metrics``; implemented by subclasses.

        Returns:
            ``(new_pop, new_pop_workflow_state, new_pbt_opt_state)``.
        """
        raise NotImplementedError

    def apply_hyperparams_to_workflow_state(
        self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric]
    ) -> State:
        """Return a copy of one member's ``workflow_state`` using ``hyperparams``.

        Implemented by subclasses; the input state must not be mutated, since
        ``exploit_and_explore`` relies on that when copying parent states.
        """
        raise NotImplementedError

    @classmethod
    def enable_jit(cls) -> None:
        """Replace ``setup``/``evaluate``/``step`` on the class by jitted versions."""
        cls.setup = jax.jit(cls.setup, static_argnums=(0,))
        cls.evaluate = jax.jit(cls.evaluate, static_argnums=(0,))
        cls.step = jax.jit(cls.step, static_argnums=(0,))
357
358
[docs]
class PBTWorkflowTemplate(PBTWorkflowBase):
    """Standard PBT Workflow Template.

    Exploit & explore selects top and bottom members via ``select``. It
    perturbs the top members' hyperparameters with ``explore`` and overwrites
    the bottom members with the resulting offspring, copying over the parents'
    workflow states.
    """

    def exploit_and_explore(
        self,
        pbt_opt_state: PBTOptState,
        pop: chex.ArrayTree,
        pop_workflow_state: State,
        pop_metrics: chex.ArrayTree,
        key: chex.PRNGKey,
    ) -> tuple[chex.ArrayTree, State, PBTOptState]:
        """Replace the worst members by perturbed copies of the best ones.

        Args:
            pbt_opt_state: Returned unchanged.
            pop: Hyperparameters of all members.
            pop_workflow_state: Workflow states of all members.
            pop_metrics: Fitness used by ``select`` (``step()`` passes the
                members' evaluation episode returns).
            key: PRNG key for selection and perturbation.

        Returns:
            ``(new_pop, new_pop_workflow_state, pbt_opt_state)``.
        """
        cfg = self.config
        exploit_key, explore_key = jax.random.split(key)

        num_bottoms = round(cfg.pop_size * cfg.bottom_ratio)
        num_tops = round(cfg.pop_size * cfg.top_ratio)
        top_indices, bottom_indices = select(
            pop_metrics,
            exploit_key,
            bottoms_num=num_bottoms,
            tops_num=num_tops,
        )

        # One perturbed offspring per replaced (bottom) member.
        perturb = partial(
            explore,
            perturb_factor=cfg.perturb_factor,
            search_space=cfg.search_space,
        )
        explore_keys = jax.random.split(explore_key, bottom_indices.shape[0])
        offsprings = jax.vmap(perturb)(tree_get(pop, top_indices), explore_keys)

        # Offspring start from their parents' workflow state with the new
        # hyperparameters applied. No deepcopy is needed: the parent states
        # must not be mutated by apply_hyperparams_to_workflow_state().
        parent_states = tree_get(pop_workflow_state, top_indices)
        offsprings_workflow_state = jax.vmap(self.apply_hyperparams_to_workflow_state)(
            parent_states, offsprings
        )

        # ==== survival | merge population ====
        new_pop = tree_set(pop, offsprings, bottom_indices, unique_indices=True)
        new_pop_workflow_state = tree_set(
            pop_workflow_state,
            offsprings_workflow_state,
            bottom_indices,
            unique_indices=True,
        )

        return new_pop, new_pop_workflow_state, pbt_opt_state

    def _record_step_metrics(self, train_metrics, workflow_metrics, iters):
        """Summarise one iteration's metrics and write them to the recorder."""
        metrics_dict = train_metrics.to_local_dict()
        member_train_metrics = metrics_dict["pop_train_metrics"]

        if "train_episode_return" in member_train_metrics:
            # Keep only real returns (drop MISSING_REWARD placeholders). Boolean
            # masking also flattens the array; the order of the remaining
            # values is irrelevant since only summary statistics are computed.
            returns = member_train_metrics["train_episode_return"]
            returns = returns[returns != MISSING_REWARD]
            member_train_metrics["train_episode_return"] = (
                returns if len(returns) > 0 else None
            )

        metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
            metrics_dict["pop_episode_returns"], histogram=True
        )
        metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
            metrics_dict["pop_episode_lengths"], histogram=True
        )
        metrics_dict["pop"] = convert_pop_to_df(metrics_dict["pop"])
        metrics_dict["pop_train_metrics"] = jtu.tree_map(
            get_1d_array_statistics, member_train_metrics
        )

        self.recorder.write(workflow_metrics.to_local_dict(), iters)
        self.recorder.write(metrics_dict, iters)

    def learn(self, state: State) -> State:
        """Run PBT iterations from the current iteration up to ``config.num_iters``.

        Metrics are recorded every iteration. The population is evaluated every
        ``config.eval_interval`` iterations and at the last one. A checkpoint
        save is requested every iteration and forced at the last one.
        """
        num_iters = self.config.num_iters
        for i in range(state.metrics.iterations, num_iters):
            iters = i + 1
            is_last_iter = iters == num_iters

            train_metrics, state = self.step(state)
            self._record_step_metrics(train_metrics, state.metrics, iters)

            if iters % self.config.eval_interval == 0 or is_last_iter:
                eval_metrics, state = self.evaluate(state)
                eval_metrics_dict = jtu.tree_map(
                    get_1d_array, eval_metrics.to_local_dict()
                )
                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            self.checkpoint_manager.save(iters, state, force=is_last_iter)

        return state
465
466
[docs]
class PBTOffpolicyWorkflowTemplate(PBTWorkflowTemplate):
    """PBT Workflow Template for Off-policy algorithms with shared replay buffer.

    All population members add their transitions to a single replay buffer,
    and every device keeps a full replica of that buffer's state.
    """

    def __init__(
        self, workflow: OffPolicyWorkflow, evaluator: Evaluator, config: DictConfig
    ):
        super().__init__(workflow, evaluator, config)
        self.replay_buffer = workflow.replay_buffer

    def setup(self, key: chex.PRNGKey):
        """Set up the population, then create and pre-fill the shared replay buffer."""
        key, rb_key = jax.random.split(key)
        state = super().setup(key)

        state = state.replace(
            replay_buffer_state=self._setup_replaybuffer(rb_key),
        )

        logger.info("Start replay buffer post-setup")
        state = self._postsetup_replaybuffer(state)

        logger.info("Complete replay buffer post-setup")

        return state

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Create the shared replay buffer state, replicated on every device."""
        # replicas across devices: every device needs one replay_buffer_state
        replay_buffer_state = shard_map(
            self.workflow._setup_replaybuffer,
            mesh=self.sharding.mesh,
            in_specs=P(),
            out_specs=P(),
            check_rep=False,
        )(key)

        return replay_buffer_state

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Pre-fill the shared replay buffer with random-policy transitions.

        Rolls out a ``RandomAgent`` for
        ``config.random_timesteps // target_workflow.num_envs`` steps and adds
        the transitions to the replay buffer. If the agents use an observation
        preprocessor, every member's preprocessor statistics are also updated
        with the collected observations. The sampled timesteps are added to
        ``metrics.sampled_timesteps_m``.
        """
        # Since the replay buffer is shared across workflows, we need an independent post-setup
        env = self.workflow.env
        action_space = env.action_space
        obs_space = env.obs_space
        num_envs = self.config.target_workflow.num_envs

        pop_workflow_state = state.pop_workflow_state
        replay_buffer_state = state.replay_buffer_state

        # NOTE(review): this is 0 when random_timesteps < num_envs — confirm
        # that rollout() / replay_buffer.add() tolerate an empty trajectory.
        rollout_length = self.config.random_timesteps // num_envs

        # ==== fill random transitions ====

        key, env_key, rollout_key = jax.random.split(state.key, num=3)
        shared_sharding = NamedSharding(self.sharding.mesh, P())

        random_agent = RandomAgent()
        # NOTE(review): fixed init key; presumably RandomAgent's state does not
        # depend on it — confirm.
        random_agent_state = random_agent.init(
            obs_space, action_space, jax.random.PRNGKey(0)
        )
        env_state = env.reset(env_key)

        trajectory, env_state = rollout(
            env_fn=env.step,
            action_fn=random_agent.compute_actions,
            env_state=env_state,
            agent_state=random_agent_state,
            key=rollout_key,
            rollout_length=rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # [T, B, ...] -> [T*B, ...]
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)
        trajectory = jax.device_put(trajectory, shared_sharding)

        if pop_workflow_state.agent_state.obs_preprocessor_state is not None:
            # update all obs_preprocessor_state by the random trajectory

            obs_preprocessor_state = (
                pop_workflow_state.agent_state.obs_preprocessor_state
            )

            # The preprocessor states are part of pop_workflow_state and thus
            # sharded along the population axis; the random trajectory is
            # replicated on every device. Map over members only and broadcast
            # the same observations to every member (in_axes=(0, None)).
            obs_preprocessor_state = shard_map(
                jax.vmap(running_statistics.update, in_axes=(0, None)),
                mesh=self.sharding.mesh,
                in_specs=(self.sharding.spec, P()),
                out_specs=self.sharding.spec,
                check_rep=False,
            )(obs_preprocessor_state, trajectory.obs)

            pop_workflow_state = pop_workflow_state.replace(
                agent_state=pop_workflow_state.agent_state.replace(
                    obs_preprocessor_state=obs_preprocessor_state
                )
            )

        replay_buffer_state = self.replay_buffer.add(replay_buffer_state, trajectory)

        sampled_timesteps_m = rollout_length * num_envs / 1e6
        workflow_metrics = state.metrics.replace(
            sampled_timesteps_m=state.metrics.sampled_timesteps_m + sampled_timesteps_m,
        )

        # pop_workflow_state must be returned too; otherwise the preprocessor
        # update above would be silently discarded.
        return state.replace(
            key=key,
            metrics=workflow_metrics,
            pop_workflow_state=pop_workflow_state,
            replay_buffer_state=replay_buffer_state,
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one PBT iteration with a shared replay buffer.

        Each member's workflow step runs against the shared replay buffer. The
        transitions of all members are gathered across devices and added to
        every buffer replica after each step. Evaluation and exploit & explore
        then follow ``PBTWorkflowBase.step``.

        Returns:
            ``(PBTOffpolicyTrainMetric, new_state)``.
        """
        pop_workflow_state = state.pop_workflow_state
        pop = state.pop
        replay_buffer_state = state.replay_buffer_state
        pbt_opt_state = state.pbt_opt_state

        # ===== step ======
        def _train_steps(pop_wf_state, replay_buffer_state):
            def _one_step(carry, _):
                pop_wf_state, replay_buffer_state = carry

                def _wf_step_wrapper(wf_state):
                    # Temporarily attach the shared buffer; detach it afterwards
                    # so it is not stored (and vmapped) per member.
                    wf_state = wf_state.replace(replay_buffer_state=replay_buffer_state)
                    train_metrics, wf_state = self.workflow.step(wf_state)
                    wf_state = wf_state.replace(replay_buffer_state=None)
                    return train_metrics, wf_state

                pop_train_metrics, pop_wf_state = jax.vmap(_wf_step_wrapper)(
                    pop_wf_state
                )

                # add replay buffer data:
                # [pop, T*B, ...] -> [pop*T*B, ...]
                trajectory = jtu.tree_map(
                    lambda x: jax.lax.collapse(x, 0, 2), pop_train_metrics.trajectory
                )
                trajectory = jax.lax.all_gather(
                    trajectory, POP_AXIS_NAME, axis=0, tiled=True
                )

                replay_buffer_state = self.replay_buffer.add(
                    replay_buffer_state, trajectory
                )
                pop_train_metrics = pop_train_metrics.replace(trajectory=None)

                return (pop_wf_state, replay_buffer_state), pop_train_metrics

            (pop_wf_state, replay_buffer_state), train_metrics = scan_and_last(
                _one_step,
                (pop_wf_state, replay_buffer_state),
                (),
                length=self.config.workflow_steps_per_iter,
            )

            return train_metrics, pop_wf_state, replay_buffer_state

        pop_train_metrics, pop_workflow_state, replay_buffer_state = shard_map(
            _train_steps,
            mesh=self.sharding.mesh,
            in_specs=(P(POP_AXIS_NAME), P()),
            out_specs=(P(POP_AXIS_NAME), P(POP_AXIS_NAME), P()),
            check_rep=False,
        )(pop_workflow_state, replay_buffer_state)

        # ===== eval ======
        eval_fn = shmap_vmap(
            self.workflow.evaluate,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        pop_eval_metrics, pop_workflow_state = eval_fn(pop_workflow_state)

        # customize your pop metrics here
        pop_episode_returns = pop_eval_metrics.episode_returns

        # ===== warmup or exploit & explore ======
        key, exploit_and_explore_key = jax.random.split(state.key)

        def _dummy_fn(pbt_opt_state, pop, pop_workflow_state, pop_metrics, key):
            # Warm-up: leave population, workflow states and optimizer untouched.
            return pop, pop_workflow_state, pbt_opt_state

        pop, pop_workflow_state, pbt_opt_state = jax.lax.cond(
            state.metrics.iterations + 1
            <= math.ceil(
                self.config.warmup_steps / self.config.workflow_steps_per_iter
            ),
            _dummy_fn,
            self.exploit_and_explore,
            pbt_opt_state,
            pop,
            pop_workflow_state,
            pop_episode_returns,
            exploit_and_explore_key,
        )

        # ===== record metrics ======
        # NOTE(review): this overwrites the value accumulated in
        # _postsetup_replaybuffer(), so the random-fill timesteps are only kept
        # if the target workflows count them themselves — confirm intended.
        workflow_metrics = state.metrics.replace(
            sampled_timesteps_m=jnp.sum(
                pop_workflow_state.metrics.sampled_timesteps / 1e6
            ),  # convert uint32 to float32
            iterations=state.metrics.iterations + 1,
        )

        train_metrics = PBTOffpolicyTrainMetric(
            pop_episode_returns=pop_eval_metrics.episode_returns,
            pop_episode_lengths=pop_eval_metrics.episode_lengths,
            pop_train_metrics=pop_train_metrics,
            pop=state.pop,  # save prev pop instead of new pop to match the metrics
            rb_size=replay_buffer_state.buffer_size,
        )

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            pop=pop,
            pop_workflow_state=pop_workflow_state,
            pbt_opt_state=pbt_opt_state,
            replay_buffer_state=replay_buffer_state,
        )

    def _record_step_metrics(self, train_metrics, workflow_metrics, iters):
        """Summarise one iteration's metrics and write them to the recorder.

        Delegates to ``PBTWorkflowTemplate._record_step_metrics`` so that
        off-policy runs also drop ``MISSING_REWARD`` placeholders from
        ``train_episode_return`` before computing statistics. ``rb_size`` is
        written unchanged.
        """
        super()._record_step_metrics(train_metrics, workflow_metrics, iters)

    def learn(self, state: State) -> State:
        """Run PBT iterations up to ``config.num_iters``.

        Same as ``PBTWorkflowTemplate.learn``, except that the replay buffer
        state is stripped from checkpoints unless ``config.save_replay_buffer``
        is set.
        """
        for i in range(state.metrics.iterations, self.config.num_iters):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            self._record_step_metrics(train_metrics, workflow_metrics, iters)

            if iters % self.config.eval_interval == 0 or iters == self.config.num_iters:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = jtu.tree_map(
                    get_1d_array,
                    eval_metrics.to_local_dict(),
                )

                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=iters == self.config.num_iters,
            )

        return state

    @classmethod
    def enable_jit(cls) -> None:
        """Jit ``setup``/``evaluate``/``step`` and the replay buffer post-setup."""
        super().enable_jit()
        cls._postsetup_replaybuffer = jax.jit(
            cls._postsetup_replaybuffer, static_argnums=(0,)
        )