Source code for evorl.algorithms.td7

  1import logging
  2from typing import Any, Sequence
  3import math
  4
  5import chex
  6import flax.linen as nn
  7import jax
  8import jax.numpy as jnp
  9import jax.tree_util as jtu
 10import optax
 11from omegaconf import DictConfig
 12
 13from evorl.distributed import psum, pmean
 14from evorl.distributed.gradients import gradient_update
 15from evorl.envs import AutoresetMode, Box, create_env, Space
 16from evorl.evaluators import Evaluator
 17from evorl.metrics import MetricBase, metric_field
 18from evorl.sample_batch import SampleBatch
 19from evorl.types import (
 20    Action,
 21    LossDict,
 22    Params,
 23    PolicyExtraInfo,
 24    PyTreeData,
 25    PyTreeDict,
 26    State,
 27    pytree_field,
 28)
 29from evorl.utils import running_statistics
 30from evorl.utils.jax_utils import (
 31    scan_and_mean,
 32    tree_stop_gradient,
 33    tree_get,
 34    right_shift_with_padding,
 35)
 36from evorl.evaluators import EpisodeCollector
 37from evorl.agent import Agent, AgentState
 38from evorl.replay_buffers import LAPReplayBuffer
 39from evorl.recorders import add_prefix
 40from evorl.networks.linear import MLP
 41
 42from .offpolicy_utils import OffPolicyWorkflowTemplate, skip_replay_buffer_state
 43
 44logger = logging.getLogger(__name__)
 45
 46
def avg_l1_norm(x: jax.Array, eps: float = 1e-8) -> jax.Array:
    """Average L1 Norm used in TD7.

    Divides ``x`` by the mean absolute value taken along its last axis. The
    divisor is bounded below by ``eps`` so that all-zero rows stay finite.

    Args:
        x: Input array; the normalization runs along the last axis.
        eps: Lower bound applied to the divisor.

    Returns:
        An array with the same shape as ``x``.
    """
    mean_abs = jnp.mean(jnp.abs(x), axis=-1, keepdims=True)
    # jnp.maximum replaces jnp.clip(..., a_min=eps): the `a_min`/`a_max`
    # keywords of jnp.clip are deprecated in recent JAX releases, while a
    # lower-bound-only clip is exactly an element-wise maximum.
    return x / jnp.maximum(mean_abs, eps)
51 52
class TD7Encoder(nn.Module):
    """Encoder holding the state and state-action embedding MLPs of TD7.

    Attributes:
        z_s_dim: Output width of the state embedding MLP.
        z_sa_dim: Output width of the state-action embedding MLP.
        f_layer_sizes: Hidden layer widths of the state embedding MLP.
        g_layer_sizes: Hidden layer widths of the state-action embedding MLP.
    """

    z_s_dim: int = 256
    z_sa_dim: int = 256
    f_layer_sizes: Sequence[int] = (256, 256)
    g_layer_sizes: Sequence[int] = (256, 256)

    def setup(self):
        """Create both sub-networks; the embedding width is the last layer."""
        state_sizes = (*self.f_layer_sizes, self.z_s_dim)
        state_action_sizes = (*self.g_layer_sizes, self.z_sa_dim)

        self.zs_mlp = MLP(
            layer_sizes=state_sizes,
            activation=nn.elu,
            name="zs_mlp",
        )
        self.zsa_mlp = MLP(
            layer_sizes=state_action_sizes,
            activation=nn.elu,
            name="zsa_mlp",
        )

    def zs(self, obs: jax.Array) -> jax.Array:
        """Embed an observation and apply ``avg_l1_norm`` to the result."""
        return avg_l1_norm(self.zs_mlp(obs))

    def zsa(self, z_s: jax.Array, action: jax.Array) -> jax.Array:
        """Embed the concatenation of a state embedding and an action."""
        joint = jnp.concatenate([z_s, action], axis=-1)
        return self.zsa_mlp(joint)

    def __call__(
        self, obs: jax.Array, action: jax.Array
    ) -> tuple[jax.Array, jax.Array]:
        """Run both embeddings in sequence.

        Exists mainly so that a single ``init`` call creates the parameters
        of both sub-networks.

        Returns:
            The pair ``(z_s, z_sa)``.
        """
        state_emb = self.zs(obs)
        return state_emb, self.zsa(state_emb, action)
86 87
class TD7Actor(nn.Module):
    """Deterministic TD7 policy network.

    Attributes:
        action_size: Width of the action output.
        z_s_dim: Width of the state embedding (not read inside ``__call__``).
        state_emb_dim: Width of the first dense layer applied to ``obs``.
        hidden_layer_sizes: Hidden layer widths of the policy MLP.
    """

    action_size: int
    z_s_dim: int = 256
    state_emb_dim: int = 256
    hidden_layer_sizes: Sequence[int] = (256, 256)

    @nn.compact
    def __call__(self, obs: jax.Array, z_s: jax.Array) -> jax.Array:
        """Map an observation and its state embedding to a tanh-squashed action."""
        # Project the raw observation, then normalize it before it is joined
        # with the encoder's state embedding.
        obs_emb = avg_l1_norm(nn.Dense(self.state_emb_dim, name="l0")(obs))
        features = jnp.concatenate([obs_emb, z_s], axis=-1)

        policy_head = MLP(
            layer_sizes=(*self.hidden_layer_sizes, self.action_size),
            activation=nn.relu,
            name="actor_mlp",
        )
        return nn.tanh(policy_head(features))
107 108
class TD7Critic(nn.Module):
    """Twin Q-network of TD7.

    Attributes:
        z_s_dim: Width of the state embedding (not read inside ``__call__``).
        z_sa_dim: Width of the state-action embedding (not read inside
            ``__call__``).
        state_action_emb_dim: Width of the first dense layer of each Q head.
        hidden_layer_sizes: Hidden layer widths of each Q head's MLP.
    """

    z_s_dim: int = 256
    z_sa_dim: int = 256
    state_action_emb_dim: int = 256
    hidden_layer_sizes: Sequence[int] = (256, 256)

    @nn.compact
    def __call__(
        self, obs: jax.Array, action: jax.Array, z_sa: jax.Array, z_s: jax.Array
    ) -> jax.Array:
        """Evaluate both Q heads.

        Returns:
            The two heads' outputs concatenated along the last axis
            (head ``q1`` first, then ``q2``).
        """
        sa = jnp.concatenate([obs, action], axis=-1)
        head_sizes = (*self.hidden_layer_sizes, 1)

        # Both heads share the same architecture and inputs; only the
        # submodule names ("q1_0"/"q1_mlp" and "q2_0"/"q2_mlp") differ.
        q_values = []
        for head in ("q1", "q2"):
            emb = nn.Dense(self.state_action_emb_dim, name=f"{head}_0")(sa)
            emb = avg_l1_norm(emb)
            head_input = jnp.concatenate([emb, z_sa, z_s], axis=-1)
            q = MLP(
                layer_sizes=head_sizes,
                activation=nn.elu,
                name=f"{head}_mlp",
            )(head_input)
            q_values.append(q)

        return jnp.concatenate(q_values, axis=-1)
142 143
class TD7TrainMetric(MetricBase):
    """Training metrics emitted by the TD7 workflow's ``step``."""

    # Scalar losses of the three networks updated by TD7.
    critic_loss: chex.Array
    actor_loss: chex.Array
    encoder_loss: chex.Array
    # Merged loss dictionaries of the three updates; reduced with `pmean`.
    raw_loss_dict: LossDict = metric_field(default_factory=PyTreeDict, reduce_fn=pmean)
149 150
class TD7NetworkParams(PyTreeData):
    """Container for every parameter set tracked by the TD7 agent."""

    # Online networks: these are the parameters the optimizers update.
    actor_params: Params
    critic_params: Params
    encoder_params: Params
    # Used by `critic_loss` to build the Q target.
    target_actor_params: Params
    target_critic_params: Params
    # Encoder copy used for acting and for the current-Q / actor losses.
    fixed_encoder_params: Params
    # Encoder copy used when computing the Q target.
    fixed_encoder_target_params: Params
    # Parameters used by `evaluate_actions`.
    checkpoint_actor_params: Params
    checkpoint_encoder_params: Params
161 162
class TD7Agent(Agent):
    """The Agent for TD7.

    Attributes:
        critic_network: Twin Q-network module.
        actor_network: Deterministic policy module.
        encoder_network: Module exposing the ``zs`` and ``zsa`` embeddings.
        obs_preprocessor: Optional callable ``(obs, state) -> obs``; ``None``
            disables observation normalization.
        discount: Discount factor used in the Q target.
        exploration_epsilon: Scale of the Gaussian noise added when acting.
        policy_noise: Scale of the Gaussian noise added to target actions.
        clip_policy_noise: Absolute bound applied to the target-action noise.
        min_priority: Threshold of the Huber-style critic loss and lower bound
            of the returned priorities.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    encoder_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    discount: float = 0.99
    exploration_epsilon: float = 0.1
    policy_noise: float = 0.2
    clip_policy_noise: float = 0.5
    min_priority: float = 1.0

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def _preprocess(self, agent_state: AgentState, obs):
        """Apply the observation preprocessor when one is configured."""
        if not self.normalize_obs:
            return obs
        return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Every derived parameter slot (target, fixed, checkpoint) starts as a
        copy of the freshly initialized network it shadows.
        """
        key, q_key, actor_key, enc_key = jax.random.split(key, num=4)

        # Batched dummy inputs, used only to trace shapes during init.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))
        dummy_action = action_space.sample(key)[None, ...]

        encoder_params = self.encoder_network.init(enc_key, dummy_obs, dummy_action)

        # The actor and critic take the embeddings as inputs, so run the
        # encoder once to obtain correctly shaped dummies.
        dummy_z_s, dummy_z_sa = self.encoder_network.apply(
            encoder_params, dummy_obs, dummy_action
        )

        critic_params = self.critic_network.init(
            q_key, dummy_obs, dummy_action, dummy_z_sa, dummy_z_s
        )
        actor_params = self.actor_network.init(actor_key, dummy_obs, dummy_z_s)

        params_state = TD7NetworkParams(
            actor_params=actor_params,
            critic_params=critic_params,
            encoder_params=encoder_params,
            target_actor_params=actor_params,
            target_critic_params=critic_params,
            fixed_encoder_params=encoder_params,
            fixed_encoder_target_params=encoder_params,
            checkpoint_actor_params=actor_params,
            checkpoint_encoder_params=encoder_params,
        )

        obs_preprocessor_state = (
            running_statistics.init_state(tree_get(dummy_obs, 0))
            if self.normalize_obs
            else None
        )

        f32_limits = jnp.finfo(jnp.float32)

        def _f32_scalar(value):
            return jnp.array(value, dtype=jnp.float32)

        # Running Q range (starts inverted so the first batch overwrites it),
        # the clipping bounds currently applied to targets, and the best
        # performance seen so far.
        extra_state = PyTreeDict(
            max_q=_f32_scalar(f32_limits.min),
            min_q=_f32_scalar(f32_limits.max),
            max_target=_f32_scalar(0.0),
            min_target=_f32_scalar(0.0),
            best_perf=_f32_scalar(f32_limits.min),
        )

        return AgentState(
            params=params_state,
            obs_preprocessor_state=obs_preprocessor_state,
            extra_state=extra_state,
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return exploratory actions: actor output plus clipped Gaussian noise."""
        params = agent_state.params
        obs = self._preprocess(agent_state, sample_batch.obs)

        # Acting goes through the fixed encoder, not the one being trained.
        z_s = self.encoder_network.apply(
            params.fixed_encoder_params, obs, method="zs"
        )
        policy_actions = self.actor_network.apply(params.actor_params, obs, z_s)

        noise = (
            jax.random.normal(key, policy_actions.shape) * self.exploration_epsilon
        )
        actions = jnp.clip(policy_actions + noise, -1.0, 1.0)

        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return noise-free actions from the checkpoint actor and encoder."""
        params = agent_state.params
        obs = self._preprocess(agent_state, sample_batch.obs)

        z_s = self.encoder_network.apply(
            params.checkpoint_encoder_params, obs, method="zs"
        )
        actions = self.actor_network.apply(params.checkpoint_actor_params, obs, z_s)

        return actions, PyTreeDict()

    def encoder_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Mean squared error between ``zsa(zs(obs), action)`` and ``zs(next_obs)``.

        The next-state embedding is treated as a constant target.
        """
        encoder_params = agent_state.params.encoder_params
        actions = sample_batch.actions
        obs = self._preprocess(agent_state, sample_batch.obs)
        next_obs = self._preprocess(
            agent_state, sample_batch.extras.env_extras.ori_obs
        )

        target_z_s = jax.lax.stop_gradient(
            self.encoder_network.apply(encoder_params, next_obs, method="zs")
        )

        z_s = self.encoder_network.apply(encoder_params, obs, method="zs")
        pred_z_sa = self.encoder_network.apply(
            encoder_params, z_s, actions, method="zsa"
        )

        enc_loss = optax.squared_error(pred_z_sa, target_z_s).mean()
        return PyTreeDict(encoder_loss=enc_loss)

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the critic loss together with priorities and target statistics.

        Returns:
            A dict with ``critic_loss``, the mean ``q_value``, per-sample
            ``priority`` and the batch extremes of the Q target
            (``batch_q_max`` / ``batch_q_min``).
        """
        params = agent_state.params
        extra_state = agent_state.extra_state
        env_extras = sample_batch.extras.env_extras
        actions = sample_batch.actions

        obs = self._preprocess(agent_state, sample_batch.obs)
        next_obs = self._preprocess(agent_state, env_extras.ori_obs)

        # ---- Q target, built entirely from target / fixed-target parameters.
        next_z_s = self.encoder_network.apply(
            params.fixed_encoder_target_params, next_obs, method="zs"
        )

        smoothing_noise = jnp.clip(
            jax.random.normal(key, actions.shape) * self.policy_noise,
            -self.clip_policy_noise,
            self.clip_policy_noise,
        )
        target_actions = self.actor_network.apply(
            params.target_actor_params, next_obs, next_z_s
        )
        next_actions = jnp.clip(target_actions + smoothing_noise, -1.0, 1.0)

        next_z_sa = self.encoder_network.apply(
            params.fixed_encoder_target_params,
            next_z_s,
            next_actions,
            method="zsa",
        )
        next_qs = self.critic_network.apply(
            params.target_critic_params,
            next_obs,
            next_actions,
            next_z_sa,
            next_z_s,
        )
        min_next_q = next_qs.min(-1)

        discounts = self.discount * (1 - env_extras.termination)

        # The bootstrapped value is clipped to the tracked target range.
        clipped_next_q = jnp.clip(
            min_next_q, extra_state.min_target, extra_state.max_target
        )
        q_target = sample_batch.rewards + discounts * clipped_next_q
        # Duplicate the target so it lines up with both Q heads.
        q_target = jnp.broadcast_to(q_target[..., None], (*q_target.shape, 2))
        q_target = jax.lax.stop_gradient(q_target)

        # ---- Current Q estimate, using the fixed encoder.
        z_s = self.encoder_network.apply(
            params.fixed_encoder_params, obs, method="zs"
        )
        z_sa = self.encoder_network.apply(
            params.fixed_encoder_params, z_s, actions, method="zsa"
        )
        qs = self.critic_network.apply(params.critic_params, obs, actions, z_sa, z_s)

        td_error = jnp.abs(qs - q_target)

        # LAP Huber loss: quadratic below `min_priority`, linear above it.
        per_head_loss = jnp.where(
            td_error < self.min_priority,
            0.5 * jnp.square(td_error),
            self.min_priority * td_error,
        )
        critic_loss = per_head_loss.sum(-1).mean()

        # Both target columns are identical, so reading column 0 is enough.
        target_column = q_target[..., 0]
        batch_q_max = target_column.max()
        batch_q_min = target_column.min()

        priority = jnp.maximum(td_error.max(axis=-1), self.min_priority)

        return PyTreeDict(
            critic_loss=critic_loss,
            q_value=qs.mean(),
            priority=jax.lax.stop_gradient(priority),
            batch_q_max=jax.lax.stop_gradient(batch_q_max),
            batch_q_min=jax.lax.stop_gradient(batch_q_min),
        )

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Negative mean Q-value (over both heads) of the actor's own actions."""
        params = agent_state.params
        obs = self._preprocess(agent_state, sample_batch.obs)

        z_s = self.encoder_network.apply(
            params.fixed_encoder_params, obs, method="zs"
        )
        actions = self.actor_network.apply(params.actor_params, obs, z_s)
        z_sa = self.encoder_network.apply(
            params.fixed_encoder_params, z_s, actions, method="zsa"
        )
        qs = self.critic_network.apply(params.critic_params, obs, actions, z_sa, z_s)

        return PyTreeDict(actor_loss=-jnp.mean(qs))
421 422
def make_td7_agent(
    action_space: Space,
    z_s_dim: int = 256,
    z_sa_dim: int = 256,
    f_layer_sizes: Sequence[int] = (256, 256),
    g_layer_sizes: Sequence[int] = (256, 256),
    state_emb_dim: int = 256,
    state_action_emb_dim: int = 256,
    critic_hidden_layer_sizes: Sequence[int] = (256, 256),
    actor_hidden_layer_sizes: Sequence[int] = (256, 256),
    discount: float = 0.99,
    exploration_epsilon: float = 0.1,
    policy_noise: float = 0.2,
    clip_policy_noise: float = 0.5,
    min_priority: float = 1.0,
    normalize_obs: bool = False,
):
    """Build a :class:`TD7Agent` together with its three networks.

    Args:
        action_space: Action space of the environment; must be a ``Box``.
            The action width is read from ``action_space.shape[0]``.
        z_s_dim: Width of the state embedding.
        z_sa_dim: Width of the state-action embedding.
        f_layer_sizes: Hidden layer widths of the state embedding MLP.
        g_layer_sizes: Hidden layer widths of the state-action embedding MLP.
        state_emb_dim: Width of the actor's first dense layer.
        state_action_emb_dim: Width of each Q head's first dense layer.
        critic_hidden_layer_sizes: Hidden layer widths of each Q head.
        actor_hidden_layer_sizes: Hidden layer widths of the actor MLP.
        discount: Discount factor.
        exploration_epsilon: Scale of the exploration noise.
        policy_noise: Scale of the target-action noise.
        clip_policy_noise: Absolute bound of the target-action noise.
        min_priority: Threshold used by the critic loss and priorities.
        normalize_obs: If true, observations are normalized with
            ``running_statistics.normalize``.

    Returns:
        The configured :class:`TD7Agent`.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    # The message previously read "continue"; "continuous" is what is meant.
    assert isinstance(action_space, Box), (
        "Only continuous action spaces are supported."
    )

    action_size = action_space.shape[0]

    encoder_network = TD7Encoder(
        z_s_dim=z_s_dim,
        z_sa_dim=z_sa_dim,
        f_layer_sizes=f_layer_sizes,
        g_layer_sizes=g_layer_sizes,
    )
    critic_network = TD7Critic(
        z_s_dim=z_s_dim,
        z_sa_dim=z_sa_dim,
        state_action_emb_dim=state_action_emb_dim,
        hidden_layer_sizes=critic_hidden_layer_sizes,
    )
    actor_network = TD7Actor(
        z_s_dim=z_s_dim,
        state_emb_dim=state_emb_dim,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        action_size=action_size,
    )

    # `None` is the agent's signal that normalization is disabled.
    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return TD7Agent(
        encoder_network=encoder_network,
        critic_network=critic_network,
        actor_network=actor_network,
        obs_preprocessor=obs_preprocessor,
        discount=discount,
        exploration_epsilon=exploration_epsilon,
        policy_noise=policy_noise,
        clip_policy_noise=clip_policy_noise,
        min_priority=min_priority,
    )
479 480
class TD7Workflow(OffPolicyWorkflowTemplate):
    """Off-policy workflow that trains a :class:`TD7Agent` on whole-episode rollouts."""

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "TD7"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Assemble environments, agent, optimizer, buffer and collectors.

        Args:
            config: Workflow configuration. The fields read here are visible
                in the body (``env``, ``num_envs``, ``agent_network.*``,
                ``optimizer.*``, the replay buffer settings, ...).

        Returns:
            A workflow instance with an extra ``collector`` attribute used by
            :meth:`step` for training rollouts.
        """
        assert config.rollout_episodes % config.num_envs == 0, (
            "rollout_episodes must be divisible by num_envs"
        )

        # Training env: no autoreset, and the original observation is recorded
        # so that it can serve as the next observation in the losses.
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_td7_agent(
            action_space=env.action_space,
            z_s_dim=config.agent_network.zs_dim,
            z_sa_dim=config.agent_network.zsa_dim,
            f_layer_sizes=config.agent_network.f_layer_sizes,
            g_layer_sizes=config.agent_network.g_layer_sizes,
            state_emb_dim=config.agent_network.state_emb_dim,
            state_action_emb_dim=config.agent_network.state_action_emb_dim,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            min_priority=config.min_priority,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is optional; a missing or non-positive norm
        # disables it.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        replay_buffer = LAPReplayBuffer(
            capacity=config.replay_buffer_capacity,
            sample_batch_size=config.batch_size,
            alpha=config.lap_alpha,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        # Evaluation acts with the checkpoint parameters (`evaluate_actions`).
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Training rollouts act with exploration noise (`compute_actions`).
        collector = EpisodeCollector(
            env=env,
            action_fn=agent.compute_actions,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        workflow = cls(
            env,
            agent,
            optimizer,
            evaluator,
            replay_buffer,
            config,
        )
        workflow.collector = collector
        return workflow

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent and one optimizer state per trained network."""
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)

        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
            encoder=self.optimizer.init(agent_state.params.encoder_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one iteration: rollout, buffer insert, updates, checkpoint test.

        Args:
            state: Current workflow state.

        Returns:
            A pair of the reduced training metrics and the updated state.

        Raises:
            ValueError: If ``config.checkpoint_metric`` is not one of
                ``"min"``, ``"max"`` or ``"mean"``.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # Parameters that generate this iteration's rollout. `perf` below is
        # measured on that rollout, so a checkpoint promotion must store
        # these, not the parameters produced by the updates that follow.
        rollout_params = state.agent_state.params

        # Collect whole episodes (rollout [T, episodes, ...])
        eval_metrics, trajectory = self.collector.rollout(
            state.agent_state, rollout_key, self.config.rollout_episodes
        )

        trajectory = trajectory.replace(next_obs=None)

        # Mask out padded steps based on `dones` array (since autoreset is OFF)
        mask = jnp.logical_not(right_shift_with_padding(trajectory.dones, 1))
        trajectory = trajectory.replace(dones=None)

        def _flatten_fn(x):
            # Merge the two leading axes into a single batch axis.
            return x.reshape(-1, *x.shape[2:])

        trajectory = jtu.tree_map(_flatten_fn, trajectory)
        mask = jtu.tree_map(_flatten_fn, mask)

        trajectory, mask = tree_stop_gradient((trajectory, mask))

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory, mask=mask
        )

        # Gradient update wrappers: each one swaps the differentiated
        # parameters into the nested TD7NetworkParams before calling the loss.
        def encoder_loss_fn(params, agent_state, sample_batch, key):
            temp_agent_state = agent_state.replace(
                params=agent_state.params.replace(encoder_params=params)
            )
            loss_dict = self.agent.encoder_loss(temp_agent_state, sample_batch, key)
            return loss_dict.encoder_loss, loss_dict

        def critic_loss_fn(params, agent_state, sample_batch, key):
            temp_agent_state = agent_state.replace(
                params=agent_state.params.replace(critic_params=params)
            )
            loss_dict = self.agent.critic_loss(temp_agent_state, sample_batch, key)
            return loss_dict.critic_loss, loss_dict

        def actor_loss_fn(params, agent_state, sample_batch, key):
            temp_agent_state = agent_state.replace(
                params=agent_state.params.replace(actor_params=params)
            )
            loss_dict = self.agent.actor_loss(temp_agent_state, sample_batch, key)
            return loss_dict.actor_loss, loss_dict

        encoder_update_fn = gradient_update(
            encoder_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
        )

        critic_update_fn = gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
        )

        actor_update_fn = gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
        )

        def _sample_and_update_fn(carry, t):
            key, agent_state, opt_state, replay_state, training_steps = carry

            enc_opt_state = opt_state.encoder
            critic_opt_state = opt_state.critic
            actor_opt_state = opt_state.actor

            key, enc_key, critic_key, actor_key, rb_key = jax.random.split(key, num=5)

            # Sample from the replay buffer; the returned state is what
            # `update_priority` is applied to further down.
            sample_batch, _weights, replay_state = self.replay_buffer.sample(
                replay_state, rb_key
            )

            # 1. Update Encoder
            (enc_loss, enc_loss_dict), enc_params, enc_opt_state = encoder_update_fn(
                enc_opt_state,
                agent_state.params.encoder_params,
                agent_state,
                sample_batch,
                enc_key,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(encoder_params=enc_params)
            )

            # 2. Update Critic
            (critic_loss, critic_loss_dict), critic_params, critic_opt_state = (
                critic_update_fn(
                    critic_opt_state,
                    agent_state.params.critic_params,
                    agent_state,
                    sample_batch,
                    critic_key,
                )
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            )

            # LAP priority updates
            priority = critic_loss_dict.priority
            replay_state = self.replay_buffer.update_priority(replay_state, priority)

            # Track the running extremes of the Q target for value clipping.
            new_max = jnp.maximum(
                agent_state.extra_state.max_q, critic_loss_dict.batch_q_max
            )
            new_min = jnp.minimum(
                agent_state.extra_state.min_q, critic_loss_dict.batch_q_min
            )
            agent_state = agent_state.replace(
                extra_state=agent_state.extra_state.replace(
                    max_q=new_max,
                    min_q=new_min,
                )
            )

            # 3. Update Actor (only every `policy_freq` training steps)
            def _update_actor(carry):
                agent_state, actor_opt_state = carry
                (actor_loss, actor_loss_dict), actor_params, actor_opt_state = (
                    actor_update_fn(
                        actor_opt_state,
                        agent_state.params.actor_params,
                        agent_state,
                        sample_batch,
                        actor_key,
                    )
                )
                agent_state = agent_state.replace(
                    params=agent_state.params.replace(actor_params=actor_params)
                )
                return agent_state, actor_opt_state, actor_loss, actor_loss_dict

            def _skip_actor(carry):
                # Must return the same structure as `_update_actor`; zeros
                # stand in for the loss that was not computed.
                agent_state, actor_opt_state = carry
                return (
                    agent_state,
                    actor_opt_state,
                    jnp.array(0.0),
                    PyTreeDict(actor_loss=jnp.array(0.0)),
                )

            agent_state, actor_opt_state, actor_loss, actor_loss_dict = jax.lax.cond(
                (training_steps + 1) % self.config.policy_freq == 0,
                _update_actor,
                _skip_actor,
                (agent_state, actor_opt_state),
            )

            # 4. Hard target updates
            def _hard_target_updates(carry):
                agent_state, replay_state = carry
                # All right-hand sides read the pre-update `agent_state`, so
                # the fixed-target encoder receives the old fixed encoder
                # while the fixed encoder receives the trained encoder.
                agent_state = agent_state.replace(
                    params=agent_state.params.replace(
                        target_actor_params=agent_state.params.actor_params,
                        target_critic_params=agent_state.params.critic_params,
                        fixed_encoder_target_params=agent_state.params.fixed_encoder_params,
                        fixed_encoder_params=agent_state.params.encoder_params,
                    ),
                    extra_state=agent_state.extra_state.replace(
                        max_target=agent_state.extra_state.max_q,
                        min_target=agent_state.extra_state.min_q,
                    ),
                )
                replay_state = self.replay_buffer.reset_max_priority(replay_state)
                return agent_state, replay_state

            def _skip_updates(carry):
                return carry

            agent_state, replay_state = jax.lax.cond(
                (training_steps + 1) % self.config.target_update_rate == 0,
                _hard_target_updates,
                _skip_updates,
                (agent_state, replay_state),
            )

            opt_state = opt_state.replace(
                encoder=enc_opt_state, actor=actor_opt_state, critic=critic_opt_state
            )

            return (
                (key, agent_state, opt_state, replay_state, training_steps + 1),
                (
                    enc_loss,
                    critic_loss,
                    actor_loss,
                    enc_loss_dict,
                    critic_loss_dict,
                    actor_loss_dict,
                ),
            )

        # Global training-step counter derived from the iteration count, so
        # the actor / target update schedules continue across iterations.
        global_steps = state.metrics.iterations * self.config.num_updates_per_iter

        iters = jnp.arange(self.config.num_updates_per_iter, dtype=jnp.int32)

        (
            (_, agent_state, opt_state, replay_buffer_state, _),
            (
                encoder_loss,
                critic_loss,
                actor_loss,
                enc_loss_dict,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (
                learn_key,
                agent_state,
                state.opt_state,
                replay_buffer_state,
                global_steps,
            ),
            iters,
            length=self.config.num_updates_per_iter,
        )

        # Episodic checkpointing: score the rollout, then maybe promote.
        if self.config.checkpoint_metric == "mean":
            perf = jnp.mean(eval_metrics.episode_returns)
        elif self.config.checkpoint_metric == "min":
            perf = jnp.min(eval_metrics.episode_returns)
        elif self.config.checkpoint_metric == "max":
            perf = jnp.max(eval_metrics.episode_returns)
        else:
            raise ValueError(
                f"Unsupported checkpoint metric: {self.config.checkpoint_metric}. "
                "Must be one of 'min', 'max', or 'mean'."
            )

        def _update_checkpoint(ag_state):
            # Store the actor / fixed encoder that actually achieved `perf`
            # (the pre-update rollout parameters). Storing the post-update
            # parameters would pair `best_perf` with a policy that has never
            # been evaluated.
            return ag_state.replace(
                params=ag_state.params.replace(
                    checkpoint_actor_params=rollout_params.actor_params,
                    checkpoint_encoder_params=rollout_params.fixed_encoder_params,
                ),
                extra_state=ag_state.extra_state.replace(best_perf=perf),
            )

        agent_state = jax.lax.cond(
            perf >= agent_state.extra_state.best_perf,
            _update_checkpoint,
            lambda ag_state: ag_state,
            agent_state,
        )

        # Skipped actor updates contribute zeros to the scanned mean, which
        # effectively divides the actor loss by policy_freq; scale it back.
        actor_loss = actor_loss * self.config.policy_freq

        train_metrics = TD7TrainMetric(
            encoder_loss=encoder_loss,
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict(
                {**enc_loss_dict, **critic_loss_dict, **actor_loss_dict}
            ),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        sampled_timesteps = jnp.uint32(eval_metrics.episode_lengths.sum())
        sampled_timesteps = psum(sampled_timesteps, axis_name=self.dp_axis_name)

        sampled_episodes = psum(
            jnp.uint32(self.config.rollout_episodes), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # `env_state` is deliberately left untouched here: the rollout above
        # goes through `self.collector`, and no env state is read or written
        # in this method.
        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run the outer training loop until ``total_episodes`` are sampled.

        Each loop iteration calls ``_multi_steps``, writes the metrics, runs
        an evaluation when due, and asks the checkpoint manager to save.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The final workflow state.
        """
        num_devices = jax.device_count()
        one_step_episodes = self.config.rollout_episodes * num_devices
        sampled_episodes = state.metrics.sampled_episodes.tolist()
        # Remaining episodes divided by the episodes consumed per
        # `_multi_steps` call, rounded up.
        num_iters = math.ceil(
            (self.config.total_episodes - sampled_episodes)
            / (one_step_episodes * self.config.fold_iters)
        )
        start_iteration = state.metrics.iterations.tolist()
        final_iteration = num_iters + start_iteration

        for i in range(start_iteration, final_iteration):
            iterations = i + 1
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            # Evaluate on the configured interval and always at the very end.
            if (
                iterations % self.config.eval_interval == 0
                or iterations == final_iteration
            ):
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            # The last iteration is always written (`force=True`).
            self.checkpoint_manager.save(
                iterations, saved_state, force=iterations == final_iteration
            )

        return state