Source code for evorl.algorithms.meta.pbt_workflow

  1import copy
  2import logging
  3import math
  4from functools import partial
  5from omegaconf import DictConfig, OmegaConf, open_dict, read_write
  6
  7import chex
  8import hydra
  9import jax
 10import jax.numpy as jnp
 11import jax.tree_util as jtu
 12from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
 13from jax import shard_map
 14
 15
 16from evorl.agent import RandomAgent
 17from evorl.distributed import (
 18    POP_AXIS_NAME,
 19    shmap_vmap,
 20    tree_device_put,
 21)
 22from evorl.rollout import rollout
 23from evorl.metrics import MetricBase, EvaluateMetric
 24from evorl.envs import AutoresetMode, create_env
 25from evorl.evaluators import Evaluator
 26from evorl.types import MISSING_REWARD, PyTreeDict, State, PyTreeData
 27from evorl.replay_buffers import ReplayBufferState
 28from evorl.recorders import get_1d_array_statistics, get_1d_array, add_prefix
 29from evorl.utils.rl_toolkits import flatten_rollout_trajectory
 30from evorl.utils.jax_utils import (
 31    tree_get,
 32    tree_set,
 33    tree_stop_gradient,
 34    scan_and_last,
 35    is_jitted,
 36)
 37from evorl.utils import running_statistics
 38from evorl.workflows import RLWorkflow, OffPolicyWorkflow, Workflow
 39
 40from .pbt_utils import convert_pop_to_df
 41from .pbt_operations import explore, select
 42from ..offpolicy_utils import clean_trajectory, skip_replay_buffer_state
 43
 44logger = logging.getLogger(__name__)
 45
 46
class PBTTrainMetric(MetricBase):
    """Training metrics of one PBT iteration, collected over the whole population."""

    # evaluation episode return / length of every member (leading axis = population)
    pop_episode_returns: chex.Array
    pop_episode_lengths: chex.Array
    # train metrics of the target workflow's last step, stacked over the population
    pop_train_metrics: MetricBase
    # hyperparameters that were in use during this iteration
    pop: chex.ArrayTree

class PBTOffpolicyTrainMetric(PBTTrainMetric):
    """PBT training metrics plus the size of the shared replay buffer."""

    # number of transitions currently stored in the shared replay buffer
    rb_size: chex.Array

class PBTEvalMetric(MetricBase):
    """Evaluation metrics of every population member."""

    # mean evaluation episode return / length of every member (leading axis = population)
    pop_episode_returns: chex.Array
    pop_episode_lengths: chex.Array

class PBTWorkflowMetric(MetricBase):
    """Population-level progress counters of the PBT workflow."""

    # the average of sampled timesteps of all workflows (in millions)
    sampled_timesteps_m: chex.Array = jnp.zeros(shape=(), dtype=jnp.float32)
    # number of completed PBT iterations
    iterations: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)

class PBTOptState(PyTreeData):
    """State of the PBT meta-optimizer; empty here, extended by subclasses as needed."""

class PBTWorkflowBase(Workflow):
    """Base workflow for Population Based Training (PBT).

    The population holds ``config.pop_size`` copies of a target RL workflow.
    The hyperparameters (``pop``) and the per-member workflow states are sharded
    along ``POP_AXIS_NAME`` over the mesh built in ``build_from_config``, while
    the PRNG key, the workflow metrics and the PBT optimizer state are
    replicated on every device.

    Subclasses implement ``_setup_pop_and_pbt_optimizer``,
    ``exploit_and_explore`` and ``apply_hyperparams_to_workflow_state``.
    """

    def __init__(self, workflow: RLWorkflow, evaluator: Evaluator, config: DictConfig):
        super().__init__(config)

        self.workflow = workflow
        self.evaluator = evaluator
        # defaults; build_from_config() overrides both with the real mesh setup
        self.devices = jax.local_devices()[:1]
        self.sharding = None  # training sharding

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Round ``config.pop_size`` down to a multiple of ``jax.device_count()``.

        The population is split evenly across devices, so its size must be
        divisible by the device count. ``config`` is modified in place and must
        be writable.

        Raises:
            ValueError: if ``pop_size`` is smaller than the device count, since
                rounding down would otherwise leave an empty population.
        """
        num_devices = jax.device_count()

        if config.pop_size < num_devices:
            raise ValueError(
                f"pop_size({config.pop_size}) must be no less than "
                f"num_devices({num_devices})"
            )

        if config.pop_size % num_devices != 0:
            logger.warning(
                f"pop_size({config.pop_size}) cannot be divided by num_devices({num_devices}), "
                f"rescale pop_size to {config.pop_size // num_devices * num_devices}"
            )

        config.pop_size = (config.pop_size // num_devices) * num_devices

    @classmethod
    def build_from_config(
        cls, config: DictConfig, enable_multi_devices=True, enable_jit: bool = True
    ):
        """Build the PBT workflow from ``config``.

        Args:
            config: PBT config. It is deep-copied, so the caller's object is not
                modified.
            enable_multi_devices: NOTE(review): currently not read; all local
                devices are always used to build the mesh.
            enable_jit: if True, replace ``setup``/``step``/``evaluate`` on the
                class with jitted versions before the target workflow is built.

        Returns:
            The workflow instance whose ``sharding`` splits ``POP_AXIS_NAME``
            across ``jax.local_devices()``.
        """
        config = copy.deepcopy(config)  # avoid in-place modification

        devices = jax.local_devices()

        OmegaConf.set_readonly(config, False)
        cls._rescale_config(config)

        if enable_jit:
            # must run before _build_from_config(), which checks is_jitted(cls.step)
            cls.enable_jit()

        OmegaConf.set_readonly(config, True)

        workflow = cls._build_from_config(config)

        mesh = Mesh(devices, axis_names=(POP_AXIS_NAME,))
        workflow.devices = devices
        workflow.sharding = NamedSharding(mesh, P(POP_AXIS_NAME))

        return workflow

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Instantiate the target workflow, the evaluator and this PBT workflow.

        The target workflow receives a copy of ``config.env`` and has its own
        checkpointing disabled. It is jitted if and only if ``cls.step`` is.
        """
        target_workflow_config = config.target_workflow
        target_workflow_config = copy.deepcopy(target_workflow_config)
        target_workflow_cls = hydra.utils.get_class(target_workflow_config.workflow_cls)

        devices = jax.local_devices()

        with read_write(target_workflow_config):
            with open_dict(target_workflow_config):
                target_workflow_config.env = copy.deepcopy(config.env)
                # disable target workflow ckpt
                target_workflow_config.checkpoint = OmegaConf.create(dict(enable=False))

        OmegaConf.set_readonly(target_workflow_config, True)

        enable_jit = is_jitted(cls.step)
        target_workflow = target_workflow_cls.build_from_config(
            target_workflow_config, enable_jit=enable_jit
        )

        target_workflow.devices = devices

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )
        evaluator = Evaluator(
            env=eval_env,
            action_fn=target_workflow.agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(target_workflow, evaluator, config)

    def _setup_pop_and_pbt_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[chex.ArrayTree, PBTOptState]:
        """Create the initial population of hyperparameters and the PBT optimizer state."""
        raise NotImplementedError

    def _customize_optimizer(self) -> None:
        """Hook called at the start of ``setup``; no-op by default."""
        pass

    def setup(self, key: chex.PRNGKey):
        """Initialize the population and the per-member workflow states.

        Returns:
            State with ``key``, ``metrics`` and ``pbt_opt_state`` replicated on
            all devices, and ``pop`` / ``pop_workflow_state`` sharded along the
            population axis.
        """
        pop_size = self.config.pop_size
        self._customize_optimizer()

        key, workflow_key, pop_key = jax.random.split(key, num=3)

        pop, pbt_opt_state = self._setup_pop_and_pbt_optimizer(pop_key)
        pop = tree_device_put(pop, self.sharding)

        workflow_metrics = PBTWorkflowMetric()
        shared_sharding = NamedSharding(self.sharding.mesh, P())
        key, workflow_metrics, pbt_opt_state = jax.device_put(
            (key, workflow_metrics, pbt_opt_state), shared_sharding
        )

        workflow_keys = jax.random.split(workflow_key, pop_size)
        workflow_keys = jax.device_put(workflow_keys, self.sharding)
        pop_workflow_state = shmap_vmap(
            self.workflow.setup,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )(workflow_keys)

        # Note: for obs_preprocessor, we assume pop_workflow_state.agent_state.obs_preprocessor_state
        # is already same by initialization in self.workflow.setup(),
        # so we don't need sync them here.
        # Caution: for off-policy workflow with postsetup, this may not be true.

        pop_workflow_state = shmap_vmap(
            self.apply_hyperparams_to_workflow_state,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )(pop_workflow_state, pop)

        return State(
            key=key,  # shared
            metrics=workflow_metrics,  # shared
            pop_workflow_state=pop_workflow_state,  # across devices
            pop=pop,  # across devices
            pbt_opt_state=pbt_opt_state,  # shared
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one PBT iteration.

        1. Every member runs ``config.workflow_steps_per_iter`` steps of the
           target workflow (sharded over devices, vmapped within a device).
        2. Every member is evaluated with ``self.workflow.evaluate``.
        3. Once the first ``ceil(warmup_steps / workflow_steps_per_iter)``
           iterations are done, ``exploit_and_explore`` is applied with the
           evaluation episode returns as fitness.

        Returns:
            ``(train_metrics, new_state)``. ``train_metrics.pop`` holds the
            hyperparameters used during this iteration (before exploit/explore).
        """
        pop_workflow_state = state.pop_workflow_state
        pop = state.pop
        pbt_opt_state = state.pbt_opt_state

        # ===== step ======
        def _train_steps(pop_wf_state):
            def _one_step(pop_wf_state, _):
                train_metrics, pop_wf_state = jax.vmap(self.workflow.step)(pop_wf_state)
                return pop_wf_state, train_metrics

            # only the train metrics of the last step are kept
            pop_wf_state, train_metrics = scan_and_last(
                _one_step, pop_wf_state, (), length=self.config.workflow_steps_per_iter
            )

            return train_metrics, pop_wf_state

        train_steps_fn = shard_map(
            _train_steps,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        pop_train_metrics, pop_workflow_state = train_steps_fn(pop_workflow_state)

        # ===== eval ======
        eval_fn = shmap_vmap(
            self.workflow.evaluate,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        pop_eval_metrics, pop_workflow_state = eval_fn(pop_workflow_state)

        # customize your pop metrics here
        pop_episode_returns = pop_eval_metrics.episode_returns

        # ===== warmup or exploit & explore ======
        key, exploit_and_explore_key = jax.random.split(state.key)

        def _dummy_fn(pbt_opt_state, pop, pop_workflow_state, pop_metrics, key):
            # warmup: keep the population unchanged
            return pop, pop_workflow_state, pbt_opt_state

        pop, pop_workflow_state, pbt_opt_state = jax.lax.cond(
            state.metrics.iterations + 1
            <= math.ceil(
                self.config.warmup_steps / self.config.workflow_steps_per_iter
            ),
            _dummy_fn,
            self.exploit_and_explore,
            pbt_opt_state,
            pop,
            pop_workflow_state,
            pop_episode_returns,
            exploit_and_explore_key,
        )

        # ===== record metrics ======
        if hasattr(pop_workflow_state.metrics, "sampled_timesteps"):
            sampled_timesteps_m = jnp.sum(
                pop_workflow_state.metrics.sampled_timesteps / 1e6
            )
        elif hasattr(pop_workflow_state.metrics, "sampled_timesteps_m"):
            sampled_timesteps_m = jnp.sum(
                pop_workflow_state.metrics.sampled_timesteps_m
            )
        else:
            sampled_timesteps_m = jnp.zeros((), dtype=jnp.float32)

        # Note: sampled_timesteps_m is already accumulated in target_workflow
        workflow_metrics = state.metrics.replace(
            sampled_timesteps_m=sampled_timesteps_m,
            iterations=state.metrics.iterations + 1,
        )

        train_metrics = PBTTrainMetric(
            pop_episode_returns=pop_eval_metrics.episode_returns,
            pop_episode_lengths=pop_eval_metrics.episode_lengths,
            pop_train_metrics=pop_train_metrics,
            pop=state.pop,  # save prev pop instead of new pop to match the metrics
        )

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            pop=pop,
            pop_workflow_state=pop_workflow_state,
            pbt_opt_state=pbt_opt_state,
        )

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate every member for ``config.eval_episodes`` episodes.

        Returns:
            ``(PBTEvalMetric, state)``. The metric holds the mean episode
            return/length of each member; only ``state.key`` is advanced.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        def _evaluate(wf_state, key):
            # [#episodes]
            raw_eval_metrics = self.evaluator.evaluate(
                wf_state.agent_state, key, num_episodes=self.config.eval_episodes
            )

            eval_metrics = EvaluateMetric(
                episode_returns=raw_eval_metrics.episode_returns.mean(),
                episode_lengths=raw_eval_metrics.episode_lengths.mean(),
            )
            return eval_metrics

        eval_fn = shmap_vmap(
            _evaluate,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        pop_eval_metrics = eval_fn(
            state.pop_workflow_state, jax.random.split(eval_key, self.config.pop_size)
        )

        eval_metrics = PBTEvalMetric(
            pop_episode_returns=pop_eval_metrics.episode_returns,
            pop_episode_lengths=pop_eval_metrics.episode_lengths,
        )

        return eval_metrics, state.replace(key=key)

    def exploit_and_explore(
        self,
        pbt_opt_state: PBTOptState,
        pop: chex.ArrayTree,
        pop_workflow_state: State,
        pop_metrics: chex.ArrayTree,
        key: chex.PRNGKey,
    ) -> tuple[chex.ArrayTree, State, PBTOptState]:
        """Update the population from ``pop_metrics``.

        Returns ``(pop, pop_workflow_state, pbt_opt_state)``. Called inside
        ``jax.lax.cond`` in ``step``, so it must be traceable.
        """
        raise NotImplementedError

    def apply_hyperparams_to_workflow_state(
        self, workflow_state: State, hyperparams: PyTreeDict[str, chex.Numeric]
    ) -> State:
        """Return a single member's workflow state with ``hyperparams`` applied."""
        raise NotImplementedError

    @classmethod
    def enable_jit(cls) -> None:
        """Replace ``setup``/``evaluate``/``step`` on the class with jitted versions (``self`` static)."""
        cls.setup = jax.jit(cls.setup, static_argnums=(0,))
        cls.evaluate = jax.jit(cls.evaluate, static_argnums=(0,))
        cls.step = jax.jit(cls.step, static_argnums=(0,))

class PBTWorkflowTemplate(PBTWorkflowBase):
    """Standard PBT Workflow Template.

    Exploit & explore overwrites the members at the indices returned by
    ``select`` (presumably the ``bottom_ratio`` worst) with perturbed copies of
    the parents it picks (presumably from the ``top_ratio`` best), including
    the parents' workflow states.
    """

    def exploit_and_explore(
        self,
        pbt_opt_state: PBTOptState,
        pop: chex.ArrayTree,
        pop_workflow_state: State,
        pop_metrics: chex.ArrayTree,
        key: chex.PRNGKey,
    ) -> tuple[chex.ArrayTree, State, PBTOptState]:
        """Replace selected members with perturbed copies of selected parents.

        Args:
            pbt_opt_state: returned unchanged.
            pop: hyperparameters of every member (leading axis = population).
            pop_workflow_state: workflow state of every member.
            pop_metrics: fitness of every member (the evaluation episode returns
                passed from ``step``).
            key: PRNG key for selection and perturbation.

        Returns:
            ``(pop, pop_workflow_state, pbt_opt_state)`` after replacement.
        """
        cfg = self.config
        select_key, perturb_key = jax.random.split(key)

        parent_idx, replaced_idx = select(
            pop_metrics,  # using episode_return
            select_key,
            bottoms_num=round(cfg.pop_size * cfg.bottom_ratio),
            tops_num=round(cfg.pop_size * cfg.top_ratio),
        )

        # ==== explore: perturb the parents' hyperparameters ====
        perturb = partial(
            explore,
            perturb_factor=cfg.perturb_factor,
            search_space=cfg.search_space,
        )
        perturb_keys = jax.random.split(perturb_key, replaced_idx.shape[0])
        child_hparams = jax.vmap(perturb)(tree_get(pop, parent_idx), perturb_keys)

        # children inherit the parents' workflow states; no deepcopy is needed
        # as long as apply_hyperparams_to_workflow_state() does not mutate its input
        child_wf_states = jax.vmap(self.apply_hyperparams_to_workflow_state)(
            tree_get(pop_workflow_state, parent_idx), child_hparams
        )

        # ==== survival: write children over the replaced members ====
        new_pop = tree_set(pop, child_hparams, replaced_idx, unique_indices=True)
        new_pop_workflow_state = tree_set(
            pop_workflow_state,
            child_wf_states,
            replaced_idx,
            unique_indices=True,
        )

        return new_pop, new_pop_workflow_state, pbt_opt_state

    def _record_step_metrics(self, train_metrics, workflow_metrics, iters):
        """Summarize one iteration's metrics and write them to ``self.recorder``.

        ``MISSING_REWARD`` placeholders in ``train_episode_return`` are dropped
        before summarizing. Since only summary statistics are computed, losing
        the element order is harmless. If nothing is left, the entry becomes
        ``None``, which ``jtu.tree_map`` skips.
        """
        metrics = train_metrics.to_local_dict()
        member_metrics = metrics["pop_train_metrics"]

        if "train_episode_return" in member_metrics:
            returns = member_metrics["train_episode_return"]
            finished = returns[returns != MISSING_REWARD]
            member_metrics["train_episode_return"] = (
                finished if len(finished) > 0 else None
            )

        metrics["pop_episode_returns"] = get_1d_array_statistics(
            metrics["pop_episode_returns"], histogram=True
        )
        metrics["pop_episode_lengths"] = get_1d_array_statistics(
            metrics["pop_episode_lengths"], histogram=True
        )
        metrics["pop"] = convert_pop_to_df(metrics["pop"])
        metrics["pop_train_metrics"] = jtu.tree_map(
            get_1d_array_statistics, member_metrics
        )

        self.recorder.write(workflow_metrics.to_local_dict(), iters)
        self.recorder.write(metrics, iters)

    def learn(self, state: State) -> State:
        """Run PBT iterations until ``config.num_iters`` is reached.

        Resumes from ``state.metrics.iterations``. Every iteration is recorded
        and checkpointed; evaluation runs every ``config.eval_interval``
        iterations and at the final one.
        """
        num_iters = self.config.num_iters

        for done in range(state.metrics.iterations, num_iters):
            iters = done + 1
            train_metrics, state = self.step(state)
            self._record_step_metrics(train_metrics, state.metrics, iters)

            is_last = iters == num_iters
            if iters % self.config.eval_interval == 0 or is_last:
                eval_metrics, state = self.evaluate(state)
                eval_summary = jtu.tree_map(get_1d_array, eval_metrics.to_local_dict())
                self.recorder.write(add_prefix(eval_summary, "eval"), iters)

            self.checkpoint_manager.save(iters, state, force=is_last)

        return state

class PBTOffpolicyWorkflowTemplate(PBTWorkflowTemplate):
    """PBT Workflow Template for Off-policy algorithms with shared replay buffer.

    All members write their transitions into one replay buffer, which is
    replicated on every device (``PartitionSpec()``). It lives in
    ``state.replay_buffer_state`` rather than in the per-member workflow states.
    """

    def __init__(
        self, workflow: OffPolicyWorkflow, evaluator: Evaluator, config: DictConfig
    ):
        super().__init__(workflow, evaluator, config)
        self.replay_buffer = workflow.replay_buffer

    def setup(self, key: chex.PRNGKey):
        """Set up the population, then create and pre-fill the shared replay buffer."""
        key, rb_key = jax.random.split(key)
        state = super().setup(key)

        state = state.replace(
            replay_buffer_state=self._setup_replaybuffer(rb_key),
        )

        # NOTE: when setup() is jitted, these logs are emitted at trace time only
        logger.info("Start replay buffer post-setup")
        state = self._postsetup_replaybuffer(state)

        logger.info("Complete replay buffer post-setup")

        return state

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Create the shared replay buffer state, replicated on every device."""
        # replicas across devices: every device needs one replay_buffer_state
        replay_buffer_state = shard_map(
            self.workflow._setup_replaybuffer,
            mesh=self.sharding.mesh,
            in_specs=P(),
            out_specs=P(),
            check_rep=False,
        )(key)

        return replay_buffer_state

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Pre-fill the shared replay buffer with random-policy transitions.

        Rolls out a ``RandomAgent`` for ``config.random_timesteps // num_envs``
        steps and adds the flattened transitions to the shared replay buffer.
        If the agents use an observation preprocessor, every member's running
        statistics are also updated with the collected observations, and the
        updated ``pop_workflow_state`` is stored in the returned state.
        """
        # Since the replay buffer is shared across workflows, we need an independent post-setup
        env = self.workflow.env
        action_space = env.action_space
        obs_space = env.obs_space
        num_envs = self.config.target_workflow.num_envs

        pop_workflow_state = state.pop_workflow_state
        replay_buffer_state = state.replay_buffer_state

        rollout_length = self.config.random_timesteps // num_envs

        # ==== fill random transitions ====

        key, env_key, rollout_key = jax.random.split(state.key, num=3)
        shared_sharding = NamedSharding(self.sharding.mesh, P())

        random_agent = RandomAgent()
        random_agent_state = random_agent.init(
            obs_space, action_space, jax.random.PRNGKey(0)
        )
        env_state = env.reset(env_key)

        trajectory, env_state = rollout(
            env_fn=env.step,
            action_fn=random_agent.compute_actions,
            env_state=env_state,
            agent_state=random_agent_state,
            key=rollout_key,
            rollout_length=rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # [T, B, ...] -> [T*B, ...]
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)
        trajectory = jax.device_put(trajectory, shared_sharding)

        if pop_workflow_state.agent_state.obs_preprocessor_state is not None:
            # Update every member's obs_preprocessor_state with the same random
            # observations. The preprocessor states are sharded along the
            # population axis while the observations are replicated, so the
            # observations are broadcast (in_axes=None) to every member rather
            # than vmapped over.
            def _update_pop_obs_preprocessor(pop_obs_preprocessor_state, obs):
                return jax.vmap(running_statistics.update, in_axes=(0, None))(
                    pop_obs_preprocessor_state, obs
                )

            obs_preprocessor_state = shard_map(
                _update_pop_obs_preprocessor,
                mesh=self.sharding.mesh,
                in_specs=(self.sharding.spec, P()),
                out_specs=self.sharding.spec,
                check_rep=False,
            )(pop_workflow_state.agent_state.obs_preprocessor_state, trajectory.obs)

            pop_workflow_state = pop_workflow_state.replace(
                agent_state=pop_workflow_state.agent_state.replace(
                    obs_preprocessor_state=obs_preprocessor_state
                )
            )

        replay_buffer_state = self.replay_buffer.add(replay_buffer_state, trajectory)

        sampled_timesteps_m = rollout_length * num_envs / 1e6
        workflow_metrics = state.metrics.replace(
            sampled_timesteps_m=state.metrics.sampled_timesteps_m + sampled_timesteps_m,
        )

        return state.replace(
            key=key,
            metrics=workflow_metrics,
            # carries the updated obs_preprocessor_state (if any)
            pop_workflow_state=pop_workflow_state,
            replay_buffer_state=replay_buffer_state,
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one PBT iteration with a shared replay buffer.

        The flow matches ``PBTWorkflowBase.step``, with two differences. During
        training, each member's step temporarily sees the shared replay buffer.
        After every workflow step, the trajectories of all members (gathered
        across devices) are added to that buffer.

        Returns:
            ``(PBTOffpolicyTrainMetric, new_state)``.
        """
        pop_workflow_state = state.pop_workflow_state
        pop = state.pop
        replay_buffer_state = state.replay_buffer_state
        pbt_opt_state = state.pbt_opt_state

        # ===== step ======
        def _train_steps(pop_wf_state, replay_buffer_state):
            def _one_step(carry, _):
                pop_wf_state, replay_buffer_state = carry

                def _wf_step_wrapper(wf_state):
                    # attach the shared buffer for this step only; it is
                    # detached again so it is not stored per member
                    wf_state = wf_state.replace(replay_buffer_state=replay_buffer_state)
                    train_metrics, wf_state = self.workflow.step(wf_state)
                    wf_state = wf_state.replace(replay_buffer_state=None)
                    return train_metrics, wf_state

                pop_train_metrics, pop_wf_state = jax.vmap(_wf_step_wrapper)(
                    pop_wf_state
                )

                # add replay buffer data:
                # [pop, T*B, ...] -> [pop*T*B, ...]
                trajectory = jtu.tree_map(
                    lambda x: jax.lax.collapse(x, 0, 2), pop_train_metrics.trajectory
                )
                # gather the members of all devices so every replica adds the same data
                trajectory = jax.lax.all_gather(
                    trajectory, POP_AXIS_NAME, axis=0, tiled=True
                )

                replay_buffer_state = self.replay_buffer.add(
                    replay_buffer_state, trajectory
                )
                pop_train_metrics = pop_train_metrics.replace(trajectory=None)

                return (pop_wf_state, replay_buffer_state), pop_train_metrics

            (pop_wf_state, replay_buffer_state), train_metrics = scan_and_last(
                _one_step,
                (pop_wf_state, replay_buffer_state),
                (),
                length=self.config.workflow_steps_per_iter,
            )

            return train_metrics, pop_wf_state, replay_buffer_state

        pop_train_metrics, pop_workflow_state, replay_buffer_state = shard_map(
            _train_steps,
            mesh=self.sharding.mesh,
            in_specs=(P(POP_AXIS_NAME), P()),
            out_specs=(P(POP_AXIS_NAME), P(POP_AXIS_NAME), P()),
            check_rep=False,
        )(pop_workflow_state, replay_buffer_state)

        # ===== eval ======
        eval_fn = shmap_vmap(
            self.workflow.evaluate,
            mesh=self.sharding.mesh,
            in_specs=self.sharding.spec,
            out_specs=self.sharding.spec,
            check_rep=False,
        )

        pop_eval_metrics, pop_workflow_state = eval_fn(pop_workflow_state)

        # customize your pop metrics here
        pop_episode_returns = pop_eval_metrics.episode_returns

        # ===== warmup or exploit & explore ======
        key, exploit_and_explore_key = jax.random.split(state.key)

        def _dummy_fn(pbt_opt_state, pop, pop_workflow_state, pop_metrics, key):
            # warmup: keep the population unchanged
            return pop, pop_workflow_state, pbt_opt_state

        pop, pop_workflow_state, pbt_opt_state = jax.lax.cond(
            state.metrics.iterations + 1
            <= math.ceil(
                self.config.warmup_steps / self.config.workflow_steps_per_iter
            ),
            _dummy_fn,
            self.exploit_and_explore,
            pbt_opt_state,
            pop,
            pop_workflow_state,
            pop_episode_returns,
            exploit_and_explore_key,
        )

        # ===== record metrics ======
        workflow_metrics = state.metrics.replace(
            sampled_timesteps_m=jnp.sum(
                pop_workflow_state.metrics.sampled_timesteps / 1e6
            ),  # convert uint32 to float32
            iterations=state.metrics.iterations + 1,
        )

        train_metrics = PBTOffpolicyTrainMetric(
            pop_episode_returns=pop_eval_metrics.episode_returns,
            pop_episode_lengths=pop_eval_metrics.episode_lengths,
            pop_train_metrics=pop_train_metrics,
            pop=state.pop,  # save prev pop instead of new pop to match the metrics
            rb_size=replay_buffer_state.buffer_size,
        )

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            pop=pop,
            pop_workflow_state=pop_workflow_state,
            pbt_opt_state=pbt_opt_state,
            replay_buffer_state=replay_buffer_state,
        )

    def _record_step_metrics(self, train_metrics, workflow_metrics, iters):
        """Summarize and write one iteration's metrics.

        Delegates to the parent implementation so that off-policy runs also
        drop ``MISSING_REWARD`` placeholders from ``train_episode_return``
        before computing statistics. Extra fields such as ``rb_size`` are
        written through unchanged.
        """
        super()._record_step_metrics(train_metrics, workflow_metrics, iters)

    def learn(self, state: State) -> State:
        """Run PBT iterations until ``config.num_iters`` is reached.

        Same as the parent loop, except that the replay buffer state is
        stripped from checkpoints unless ``config.save_replay_buffer`` is set.
        """
        for i in range(state.metrics.iterations, self.config.num_iters):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            self._record_step_metrics(train_metrics, workflow_metrics, iters)

            if iters % self.config.eval_interval == 0 or iters == self.config.num_iters:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = jtu.tree_map(
                    get_1d_array,
                    eval_metrics.to_local_dict(),
                )

                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=iters == self.config.num_iters,
            )

        return state

    @classmethod
    def enable_jit(cls) -> None:
        """Jit ``setup``/``evaluate``/``step`` and ``_postsetup_replaybuffer`` (``self`` static)."""
        super().enable_jit()
        cls._postsetup_replaybuffer = jax.jit(
            cls._postsetup_replaybuffer, static_argnums=(0,)
        )
