Source code for evorl.algorithms.sac

  1import logging
  2from typing import Any
  3from omegaconf import DictConfig
  4
  5import chex
  6import flax.linen as nn
  7import jax
  8import jax.numpy as jnp
  9import jax.tree_util as jtu
 10import optax
 11
 12from evorl.replay_buffers import ReplayBuffer
 13from evorl.distributed import agent_gradient_update, psum, pmean
 14from evorl.distribution import get_tanh_norm_dist, get_categorical_dist
 15from evorl.envs import AutoresetMode, Box, create_env, Space, Discrete
 16from evorl.evaluators import Evaluator
 17from evorl.metrics import MetricBase, metric_field
 18from evorl.networks import make_policy_network, make_q_network, make_discrete_q_network
 19from evorl.rollout import rollout
 20from evorl.sample_batch import SampleBatch
 21from evorl.types import (
 22    Action,
 23    LossDict,
 24    Params,
 25    PolicyExtraInfo,
 26    PyTreeData,
 27    PyTreeDict,
 28    State,
 29    pytree_field,
 30)
 31from evorl.utils import running_statistics
 32from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, tree_get
 33from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
 34
 35from evorl.agent import Agent, AgentState
 36from .offpolicy_utils import OffPolicyWorkflowTemplate, clean_trajectory
 37
 38logger = logging.getLogger(__name__)
 39
 40
class SACTrainMetric(MetricBase):
    """Training metrics produced by one SAC workflow step.

    ``alpha_loss`` stays ``None`` unless the workflow fills it in (it does so
    only when the entropy temperature is being trained). ``raw_loss_dict``
    carries the merged loss dictionaries returned by the agent's loss methods
    and is reduced with ``pmean``.
    """

    # Critic (Q-function) loss.
    critic_loss: chex.Array
    # Actor (policy) loss.
    actor_loss: chex.Array
    # Temperature loss; optional.
    alpha_loss: chex.Array | None = None
    # Every auxiliary value returned by the loss functions.
    raw_loss_dict: LossDict = metric_field(
        default_factory=PyTreeDict,
        reduce_fn=pmean,
    )
46 47
class SACNetworkParams(PyTreeData):
    """Container for every parameter tree stored in a SAC agent state."""

    # Parameters of the online critic (Q) network.
    critic_params: Params
    # Parameters of the target critic; initialised as a copy of ``critic_params``.
    target_critic_params: Params
    # Parameters of the actor (policy) network.
    actor_params: Params
    # Log of the entropy temperature ``alpha``.
    log_alpha: Params
53 54
class SACAgent(Agent):
    """Soft Actor-Critic agent for continuous action spaces.

    The actor network output is split in half along its last axis and handed
    to ``get_tanh_norm_dist`` to build the action distribution. The critic
    network takes ``(obs, action)`` and returns a stack of Q-values on its
    last axis; the minimum over that axis is used wherever a single estimate
    is needed.

    Attributes:
        critic_network: Q-network module, called as ``apply(params, obs, actions)``.
        actor_network: Policy module, called as ``apply(params, obs)``.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``;
            ``None`` disables observation normalization.
        init_alpha: Initial value of the entropy temperature ``alpha``.
        discount: Discount factor applied to bootstrapped targets.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    init_alpha: float = 1.0
    discount: float = 0.99

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs):
        """Return ``obs`` run through ``obs_preprocessor`` when one is configured."""
        if self.normalize_obs:
            return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)
        return obs

    def _action_dist(self, agent_state: AgentState, obs):
        """Build the policy distribution for already-preprocessed ``obs``."""
        raw_actions = self.actor_network.apply(agent_state.params.actor_params, obs)
        # The actor emits two equally sized halves on the last axis; both are
        # forwarded positionally to get_tanh_norm_dist.
        return get_tanh_norm_dist(*jnp.split(raw_actions, 2, axis=-1))

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space, sampled to build a dummy input.
            action_space: Action space, sampled to build a dummy input and
                used to derive the target entropy.
            key: PRNG key for parameter initialization.

        Returns:
            An ``AgentState`` holding the network parameters, the optional
            observation-preprocessor state and the constant ``target_entropy``
            in ``extra_state``.
        """
        key, critic_key, actor_key = jax.random.split(key, num=3)

        # Add a leading batch axis of size 1 to the sampled dummy inputs.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))
        dummy_action = action_space.sample(key)[None, ...]

        critic_params = self.critic_network.init(critic_key, dummy_obs, dummy_action)
        # The target critic starts out identical to the online critic.
        target_critic_params = critic_params

        actor_params = self.actor_network.init(actor_key, dummy_obs)

        log_alpha = jnp.log(jnp.float32(self.init_alpha))

        params_state = SACNetworkParams(
            critic_params=critic_params,
            target_critic_params=target_critic_params,
            actor_params=actor_params,
            log_alpha=log_alpha,
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        # Target entropy is minus the number of action dimensions.
        target_entropy = -jnp.prod(jnp.array(action_space.shape, dtype=jnp.float32))

        return AgentState(
            params=params_state,
            obs_preprocessor_state=obs_preprocessor_state,
            extra_state=PyTreeDict(target_entropy=target_entropy),  # the constant
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions from the current policy for ``sample_batch.obs``.

        Returns:
            The sampled actions and an empty ``PyTreeDict`` of extra info.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions = self._action_dist(agent_state, obs).sample(seed=key)
        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the mode of the policy distribution for ``sample_batch.obs``.

        ``key`` is accepted for interface parity with ``compute_actions`` and
        is not used.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions = self._action_dist(agent_state, obs).mode()
        return actions, PyTreeDict()

    def alpha_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the entropy-temperature loss.

        Returns:
            A ``PyTreeDict`` with ``alpha_loss``, ``log_alpha`` and ``alpha``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        actions_dist = self._action_dist(agent_state, obs)
        actions = actions_dist.sample(seed=key)
        actions_logp = actions_dist.log_prob(actions)

        target_entropy = agent_state.extra_state.target_entropy
        # Official formulation: the loss is linear in alpha (not log_alpha) and
        # the entropy gap is treated as a constant.
        # An alternative that multiplies by log_alpha instead is discussed in
        # stable-baselines3/issues/36.
        alpha = jnp.exp(agent_state.params.log_alpha)
        alpha_loss = jnp.mean(
            -alpha * jax.lax.stop_gradient(actions_logp + target_entropy)
        )

        return PyTreeDict(
            alpha_loss=alpha_loss, log_alpha=agent_state.params.log_alpha, alpha=alpha
        )

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the policy loss ``mean(alpha * log_prob - min_i Q_i)``.

        Returns:
            A ``PyTreeDict`` with ``actor_loss`` and the mean ``entropy`` of
            the policy distribution.
        """
        actor_key, entropy_key = jax.random.split(key, 2)
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        alpha = jnp.exp(agent_state.params.log_alpha)

        actions_dist = self._action_dist(agent_state, obs)
        actions = actions_dist.sample(seed=actor_key)
        actions_logp = actions_dist.log_prob(actions)

        # [B, num_critics]
        qs = self.critic_network.apply(agent_state.params.critic_params, obs, actions)
        qs_min = jnp.min(qs, axis=-1)
        actor_loss = jnp.mean(alpha * actions_logp - qs_min)
        entropy = actions_dist.entropy(seed=entropy_key).mean()

        return PyTreeDict(actor_loss=actor_loss, entropy=entropy)

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the soft Bellman squared error summed over the critic stack.

        The next observation is read from
        ``sample_batch.extras.env_extras.ori_obs`` and the termination flag
        from ``sample_batch.extras.env_extras.termination``.

        Returns:
            A ``PyTreeDict`` with ``critic_loss``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        next_obs = self._preprocess_obs(
            agent_state, sample_batch.extras.env_extras.ori_obs
        )

        # Terminated transitions get a zero discount, i.e. no bootstrapping.
        discounts = self.discount * (1 - sample_batch.extras.env_extras.termination)

        alpha = jnp.exp(agent_state.params.log_alpha)

        # [B, num_critics]
        qs = self.critic_network.apply(
            agent_state.params.critic_params, obs, sample_batch.actions
        )

        next_actions_dist = self._action_dist(agent_state, next_obs)
        next_actions = next_actions_dist.sample(seed=key)
        next_actions_logp = next_actions_dist.log_prob(next_actions)
        # [B, num_critics]
        next_qs = self.critic_network.apply(
            agent_state.params.target_critic_params, next_obs, next_actions
        )
        qs_target = sample_batch.rewards + discounts * (
            jnp.min(next_qs, axis=-1) - alpha * next_actions_logp
        )
        # Every critic in the stack regresses onto the same target. Broadcast
        # to the actual critic output shape rather than a fixed width of 2 so
        # that any stack size works.
        qs_target = jnp.broadcast_to(qs_target[..., None], qs.shape)

        q_loss = optax.squared_error(qs, qs_target).sum(-1).mean()
        return PyTreeDict(critic_loss=q_loss)
216 217
class SACDiscreteAgent(Agent):
    """Soft Actor-Critic agent for discrete action spaces.

    The actor network outputs one logit per action, consumed by
    ``get_categorical_dist`` / ``softmax``. The critic network takes only the
    observation and returns Q-values shaped ``[B, num_critics, n]``; the
    minimum over the critic axis (``-2``) is used wherever a single estimate
    is needed.

    Attributes:
        critic_network: Q-network module, called as ``apply(params, obs)``.
        actor_network: Policy module, called as ``apply(params, obs)``.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``;
            ``None`` disables observation normalization.
        init_alpha: Initial value of the entropy temperature ``alpha``.
        discount: Discount factor applied to bootstrapped targets.
        target_entropy_ratio: Multiplier applied to ``log(n)`` to obtain the
            target entropy.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    init_alpha: float = 1.0
    discount: float = 0.99
    target_entropy_ratio: float = 0.98

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs):
        """Return ``obs`` run through ``obs_preprocessor`` when one is configured."""
        if self.normalize_obs:
            return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)
        return obs

    def _actor_logits(self, agent_state: AgentState, obs):
        """Run the actor network on already-preprocessed ``obs``."""
        return self.actor_network.apply(agent_state.params.actor_params, obs)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space, sampled to build a dummy input.
            action_space: Discrete action space; its ``n`` sets the target entropy.
            key: PRNG key for parameter initialization.

        Returns:
            An ``AgentState`` holding the network parameters, the optional
            observation-preprocessor state and the constant ``target_entropy``
            in ``extra_state``.
        """
        key, critic_key, actor_key = jax.random.split(key, num=3)

        # Add a leading batch axis of size 1 to the sampled dummy observation.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))

        critic_params = self.critic_network.init(critic_key, dummy_obs)
        # The target critic starts out identical to the online critic.
        target_critic_params = critic_params

        actor_params = self.actor_network.init(actor_key, dummy_obs)

        log_alpha = jnp.log(jnp.float32(self.init_alpha))

        params_state = SACNetworkParams(
            critic_params=critic_params,
            target_critic_params=target_critic_params,
            actor_params=actor_params,
            log_alpha=log_alpha,
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        # log(n) is the entropy of the uniform policy; aim for a fraction of it.
        target_entropy = self.target_entropy_ratio * jnp.log(
            jnp.float32(action_space.n)
        )

        return AgentState(
            params=params_state,
            obs_preprocessor_state=obs_preprocessor_state,
            extra_state=PyTreeDict(target_entropy=target_entropy),  # the constant
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions from the categorical policy for ``sample_batch.obs``.

        Returns:
            The sampled actions and an empty ``PyTreeDict`` of extra info.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions_dist = get_categorical_dist(self._actor_logits(agent_state, obs))
        return actions_dist.sample(seed=key), PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the mode of the categorical policy for ``sample_batch.obs``.

        ``key`` is accepted for interface parity with ``compute_actions`` and
        is not used.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions_dist = get_categorical_dist(self._actor_logits(agent_state, obs))
        return actions_dist.mode(), PyTreeDict()

    def alpha_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the entropy-temperature loss from the exact policy entropy.

        ``key`` is not used: the categorical entropy is computed in closed form.

        Returns:
            A ``PyTreeDict`` with ``alpha_loss``, ``log_alpha`` and ``alpha``
            (the same keys ``SACAgent.alpha_loss`` returns).
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        actions_dist = get_categorical_dist(self._actor_logits(agent_state, obs))
        entropy = actions_dist.entropy()

        target_entropy = agent_state.extra_state.target_entropy
        # Official formulation: linear in alpha, with the entropy gap treated
        # as a constant.
        alpha = jnp.exp(agent_state.params.log_alpha)
        alpha_loss = -jnp.mean(alpha * jax.lax.stop_gradient(target_entropy - entropy))

        return PyTreeDict(
            alpha_loss=alpha_loss,
            log_alpha=agent_state.params.log_alpha,
            alpha=alpha,
        )

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the policy loss ``-mean(alpha * H(pi) + E_pi[min_i Q_i])``.

        ``key`` is not used: both the entropy and the expected Q-value are
        computed exactly by summing over the actions.

        Returns:
            A ``PyTreeDict`` with ``actor_loss`` and the mean ``entropy``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        alpha = jnp.exp(agent_state.params.log_alpha)

        raw_actions = self._actor_logits(agent_state, obs)
        entropy = get_categorical_dist(raw_actions).entropy()
        actions_prob = nn.softmax(raw_actions)

        # [B, num_critics, n]
        qs = self.critic_network.apply(agent_state.params.critic_params, obs)
        qs_min = jnp.min(qs, axis=-2)  # [B, n]
        qs_estimate = jnp.sum(qs_min * actions_prob, axis=-1)
        actor_loss = -jnp.mean(alpha * entropy + qs_estimate)

        return PyTreeDict(actor_loss=actor_loss, entropy=entropy.mean())

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the soft Bellman squared error summed over the critic stack.

        The next observation is read from
        ``sample_batch.extras.env_extras.ori_obs`` and the termination flag
        from ``sample_batch.extras.env_extras.termination``. ``key`` is not
        used: the next-state value is an exact expectation over the actions.

        Returns:
            A ``PyTreeDict`` with ``critic_loss`` and the mean selected
            ``q_value``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        next_obs = self._preprocess_obs(
            agent_state, sample_batch.extras.env_extras.ori_obs
        )

        # Terminated transitions get a zero discount, i.e. no bootstrapping.
        discounts = self.discount * (1 - sample_batch.extras.env_extras.termination)

        alpha = jnp.exp(agent_state.params.log_alpha)

        # [B, num_critics, n]
        qs = self.critic_network.apply(agent_state.params.critic_params, obs)
        # Select the Q-value of the taken action from every critic:
        # [B, num_critics, n] -> [B, num_critics].
        qs = jnp.take_along_axis(
            qs,
            sample_batch.actions.reshape(-1, 1, 1),
            axis=-1,
        ).squeeze(-1)

        next_raw_actions = self._actor_logits(agent_state, next_obs)
        next_actions_prob = nn.softmax(next_raw_actions)
        next_actions_logp = nn.log_softmax(next_raw_actions)
        # [B, num_critics, n]
        next_qs = self.critic_network.apply(
            agent_state.params.target_critic_params, next_obs
        )
        next_qs_min = jnp.min(next_qs, axis=-2)  # [B, n]
        next_qs_estimate = jnp.sum(
            next_actions_prob * (next_qs_min - alpha * next_actions_logp), axis=-1
        )  # [B]

        qs_target = sample_batch.rewards + discounts * next_qs_estimate
        # Every critic regresses onto the same target. Broadcast to the actual
        # selected-Q shape rather than a fixed width of 2.
        qs_target = jnp.broadcast_to(qs_target[..., None], qs.shape)

        q_loss = optax.squared_error(qs, qs_target).sum(-1).mean()
        return PyTreeDict(critic_loss=q_loss, q_value=qs.mean())
379 380
def make_mlp_sac_agent(
    action_space: Space,
    num_critics: int = 2,
    critic_hidden_layer_sizes: tuple[int] = (256, 256),
    actor_hidden_layer_sizes: tuple[int] = (256, 256),
    init_alpha: float = 1.0,
    discount: float = 0.99,
    target_entropy_ratio: float = 0.98,
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
):
    """Build an MLP-based SAC agent matching the given action space.

    Args:
        action_space: ``Box`` yields a ``SACAgent``; ``Discrete`` yields a
            ``SACDiscreteAgent``.
        num_critics: Number of stacked Q-networks for the ``Box`` case. The
            ``Discrete`` case always builds a stack of 2 and ignores this.
        critic_hidden_layer_sizes: Hidden layer sizes of the critic network.
        actor_hidden_layer_sizes: Hidden layer sizes of the actor network.
        init_alpha: Initial entropy temperature.
        discount: Discount factor.
        target_entropy_ratio: Only forwarded to ``SACDiscreteAgent``.
        normalize_obs: When true, ``running_statistics.normalize`` is used as
            the observation preprocessor.
        policy_obs_key: Forwarded as ``obs_key`` to the actor network factory.
        value_obs_key: Forwarded as ``obs_key`` to the critic network factory.

    Returns:
        A ``SACAgent`` or a ``SACDiscreteAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor
            ``Discrete``.
    """
    is_continuous = isinstance(action_space, Box)
    if is_continuous:
        # Two outputs per action dimension (mean + std).
        policy_output_size = action_space.shape[0] * 2
    elif isinstance(action_space, Discrete):
        # One output per discrete action.
        policy_output_size = action_space.n
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    actor_network = make_policy_network(
        action_size=policy_output_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        obs_key=policy_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    if not is_continuous:
        discrete_critic = make_discrete_q_network(
            action_size=policy_output_size,
            n_stack=2,
            hidden_layer_sizes=critic_hidden_layer_sizes,
            obs_key=value_obs_key,
        )
        return SACDiscreteAgent(
            critic_network=discrete_critic,
            actor_network=actor_network,
            obs_preprocessor=obs_preprocessor,
            init_alpha=init_alpha,
            discount=discount,
            target_entropy_ratio=target_entropy_ratio,
        )

    continuous_critic = make_q_network(
        n_stack=num_critics,
        hidden_layer_sizes=critic_hidden_layer_sizes,
        obs_key=value_obs_key,
    )
    return SACAgent(
        critic_network=continuous_critic,
        actor_network=actor_network,
        obs_preprocessor=obs_preprocessor,
        init_alpha=init_alpha,
        discount=discount,
    )
442 443
class SACWorkflow(OffPolicyWorkflowTemplate):
    """Off-policy training workflow for Soft Actor-Critic.

    Each ``step`` collects a rollout, stores it in the replay buffer and then
    runs ``config.num_updates_per_iter`` rounds of critic / actor (and,
    optionally, temperature) updates followed by a soft target-critic update.
    """

    @classmethod
    def name(cls):
        """Return the workflow's display name."""
        return "SAC"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct env, agent, optimizer, replay buffer and evaluator from ``config``."""
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        network_cfg = config.agent_network
        agent = make_mlp_sac_agent(
            action_space=env.action_space,
            num_critics=network_cfg.num_critics,
            critic_hidden_layer_sizes=network_cfg.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=network_cfg.actor_hidden_layer_sizes,
            init_alpha=config.alpha,
            discount=config.discount,
            normalize_obs=config.normalize_obs,
            target_entropy_ratio=config.target_entropy_ratio,
            policy_obs_key=network_cfg.policy_obs_key,
            value_obs_key=network_cfg.value_obs_key,
        )

        # TODO: use different lr for critic and actor
        clip_norm = config.optimizer.grad_clip_norm
        optimizer = optax.adam(config.optimizer.lr)
        if clip_norm is not None and clip_norm > 0:
            # Clip by global norm before the Adam update.
            optimizer = optax.chain(optax.clip_by_global_norm(clip_norm), optimizer)

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            # Never sample before at least one full batch is available.
            min_sample_timesteps=max(
                config.batch_size, config.learning_start_timesteps
            ),
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, replay_buffer, config)

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialise the agent state and one optimizer state per trained component.

        The returned optimizer state has ``actor`` and ``critic`` entries, plus
        an ``alpha`` entry when ``config.adaptive_alpha`` is set.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        params = agent_state.params

        opt_state = PyTreeDict(
            dict(
                actor=self.optimizer.init(params.actor_params),
                critic=self.optimizer.init(params.critic_params),
            )
        )
        if self.config.adaptive_alpha:
            opt_state = opt_state.replace(alpha=self.optimizer.init(params.log_alpha))

        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: rollout, buffer insert, gradient updates.

        Args:
            state: Current workflow state.

        Returns:
            A ``SACTrainMetric`` averaged over the updates of this iteration,
            and the updated workflow state.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Keep the done flags before the trajectory is cleaned; they are only
        # needed for the episode counter below.
        trajectory_dones = trajectory.dones
        trajectory = tree_stop_gradient(
            flatten_rollout_trajectory(clean_trajectory(trajectory))
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            # Fold the freshly collected observations into the running statistics.
            new_preprocessor_state = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                dp_axis_name=self.dp_axis_name,
            )
            agent_state = agent_state.replace(
                obs_preprocessor_state=new_preprocessor_state
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def _scalar_loss_fn(loss_method, loss_name):
            """Wrap an agent loss method into a ``(loss, aux)`` function."""

            def loss_fn(agent_state, sample_batch, key):
                loss_dict = loss_method(agent_state, sample_batch, key)
                return getattr(loss_dict, loss_name), loss_dict

            return loss_fn

        def _update_fn_for(loss_fn, param_field):
            """Build a gradient-update function acting on one field of ``params``."""
            return agent_gradient_update(
                loss_fn,
                self.optimizer,
                dp_axis_name=self.dp_axis_name,
                has_aux=True,
                attach_fn=lambda agent_state, new_value: agent_state.replace(
                    params=agent_state.params.replace(**{param_field: new_value})
                ),
                detach_fn=lambda agent_state: getattr(agent_state.params, param_field),
            )

        critic_update_fn = _update_fn_for(
            _scalar_loss_fn(self.agent.critic_loss, "critic_loss"), "critic_params"
        )
        actor_update_fn = _update_fn_for(
            _scalar_loss_fn(self.agent.actor_loss, "actor_loss"), "actor_params"
        )
        alpha_update_fn = _update_fn_for(
            _scalar_loss_fn(self.agent.alpha_loss, "alpha_loss"), "log_alpha"
        )

        def _update_once(carry, unused_t):
            """One critic/actor(/alpha) update round plus a soft target update."""
            rng, agent_state, opt_state = carry

            critic_opt_state = opt_state.critic
            actor_opt_state = opt_state.actor

            rng, critic_key, actor_key, alpha_key, rb_key = jax.random.split(
                rng, num=5
            )

            # Extra critic-only updates performed before each actor update.
            num_extra_critic_updates = self.config.actor_update_interval - 1
            if num_extra_critic_updates > 0:

                def _critic_only_update(inner_carry, unused_t):
                    inner_rng, agent_state, critic_opt_state = inner_carry

                    inner_rng, inner_rb_key, inner_critic_key = jax.random.split(
                        inner_rng, num=3
                    )
                    # it's safe to use read-only replay_buffer_state here.
                    batch = self.replay_buffer.sample(
                        replay_buffer_state, inner_rb_key
                    )

                    _, agent_state, critic_opt_state = critic_update_fn(
                        critic_opt_state, agent_state, batch, inner_critic_key
                    )

                    return (inner_rng, agent_state, critic_opt_state), None

                rng, extra_updates_key = jax.random.split(rng)

                (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                    _critic_only_update,
                    (extra_updates_key, agent_state, critic_opt_state),
                    (),
                    length=num_extra_critic_updates,
                )

            batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(critic_opt_state, agent_state, batch, critic_key)
            )

            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(actor_opt_state, agent_state, batch, actor_key)
            )

            opt_state = opt_state.replace(
                actor=actor_opt_state, critic=critic_opt_state
            )

            if self.config.adaptive_alpha:
                # we follow the update order of the official implementation:
                # critic -> actor -> alpha
                (alpha_loss, alpha_loss_dict), agent_state, alpha_opt_state = (
                    alpha_update_fn(opt_state.alpha, agent_state, batch, alpha_key)
                )
                opt_state = opt_state.replace(alpha=alpha_opt_state)

                # Report the temperature as it is after this update.
                updated_log_alpha = agent_state.params.log_alpha
                alpha_loss_dict = alpha_loss_dict.replace(
                    log_alpha=updated_log_alpha,
                    alpha=jnp.exp(updated_log_alpha),
                )

                outputs = (
                    critic_loss,
                    actor_loss,
                    alpha_loss,
                    critic_loss_dict,
                    actor_loss_dict,
                    alpha_loss_dict,
                )
            else:
                outputs = (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict)

            params = agent_state.params
            new_target_critic_params = soft_target_update(
                params.target_critic_params,
                params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=params.replace(target_critic_params=new_target_critic_params)
            )

            return (rng, agent_state, opt_state), outputs

        (_, agent_state, opt_state), mean_outputs = scan_and_mean(
            _update_once,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )

        if self.config.adaptive_alpha:
            (
                critic_loss,
                actor_loss,
                alpha_loss,
                critic_loss_dict,
                actor_loss_dict,
                alpha_loss_dict,
            ) = mean_outputs
            train_metrics = SACTrainMetric(
                actor_loss=actor_loss,
                critic_loss=critic_loss,
                alpha_loss=alpha_loss,
                raw_loss_dict=PyTreeDict(
                    {**critic_loss_dict, **actor_loss_dict, **alpha_loss_dict}
                ),
            ).all_reduce(dp_axis_name=self.dp_axis_name)
        else:
            critic_loss, actor_loss, critic_loss_dict, actor_loss_dict = mean_outputs
            train_metrics = SACTrainMetric(
                actor_loss=actor_loss,
                critic_loss=critic_loss,
                raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
            ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        sampled_episodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        previous_metrics = state.metrics
        workflow_metrics = previous_metrics.replace(
            sampled_timesteps=previous_metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=previous_metrics.sampled_episodes + sampled_episodes,
            iterations=previous_metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        next_state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )
        return train_metrics, next_state