1import logging
2from typing import Any
3from omegaconf import DictConfig
4
5import chex
6import flax.linen as nn
7import jax
8import jax.numpy as jnp
9import jax.tree_util as jtu
10import optax
11
12from evorl.replay_buffers import ReplayBuffer
13from evorl.distributed import agent_gradient_update, psum, pmean
14from evorl.distribution import get_tanh_norm_dist, get_categorical_dist
15from evorl.envs import AutoresetMode, Box, create_env, Space, Discrete
16from evorl.evaluators import Evaluator
17from evorl.metrics import MetricBase, metric_field
18from evorl.networks import make_policy_network, make_q_network, make_discrete_q_network
19from evorl.rollout import rollout
20from evorl.sample_batch import SampleBatch
21from evorl.types import (
22 Action,
23 LossDict,
24 Params,
25 PolicyExtraInfo,
26 PyTreeData,
27 PyTreeDict,
28 State,
29 pytree_field,
30)
31from evorl.utils import running_statistics
32from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, tree_get
33from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
34
35from evorl.agent import Agent, AgentState
36from .offpolicy_utils import OffPolicyWorkflowTemplate, clean_trajectory
37
38logger = logging.getLogger(__name__)
39
40
[docs]
class SACTrainMetric(MetricBase):
    """Per-iteration training metrics emitted by ``SACWorkflow.step``."""

    # Critic and actor losses, averaged over the updates of one iteration.
    critic_loss: chex.Array
    actor_loss: chex.Array

    # Stays ``None`` when the workflow runs without ``adaptive_alpha``.
    alpha_loss: chex.Array | None = None

    # Union of the loss dicts returned by the agent's loss functions;
    # reduced with ``pmean``.
    raw_loss_dict: LossDict = metric_field(
        default_factory=PyTreeDict,
        reduce_fn=pmean,
    )
46
47
[docs]
class SACNetworkParams(PyTreeData):
    """Bundle of all parameter groups held in a SAC ``AgentState``."""

    # Parameters of the online critic network.
    critic_params: Params

    # Parameters of the target critic; the agents' ``init`` starts them as the
    # same object as ``critic_params``.
    target_critic_params: Params

    # Parameters of the actor (policy) network.
    actor_params: Params

    # Log of the entropy temperature; the losses use ``alpha = exp(log_alpha)``.
    log_alpha: Params
53
54
[docs]
class SACAgent(Agent):
    """Soft Actor-Critic agent for continuous action spaces.

    The actor network output is split in half along the last axis and handed
    to ``get_tanh_norm_dist`` to build the action distribution. The critic
    network takes ``(obs, actions)`` and returns a stack of Q-estimates along
    its last axis (the losses reduce over ``axis=-1``).

    Attributes:
        critic_network: Stacked Q-network module.
        actor_network: Policy network module.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``;
            ``None`` disables observation normalization.
        init_alpha: Initial entropy temperature; stored as ``log(init_alpha)``.
        discount: Discount factor used in the TD target.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    init_alpha: float = 1.0
    discount: float = 0.99

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor has been configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs):
        """Apply the observation preprocessor when one is configured."""
        if self.normalize_obs:
            return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)
        return obs

    def _policy_dist(self, agent_state: AgentState, obs):
        """Build the tanh-normal action distribution for preprocessed ``obs``."""
        raw_actions = self.actor_network.apply(agent_state.params.actor_params, obs)
        return get_tanh_norm_dist(*jnp.split(raw_actions, 2, axis=-1))

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; sampled once for a dummy input.
            action_space: Action space; sampled once for a dummy input and used
                to derive the target entropy.
            key: PRNG key for network initialization and dummy sampling.

        Returns:
            An ``AgentState`` with network parameters, the optional observation
            statistics and ``extra_state.target_entropy``.
        """
        key, critic_key, actor_key = jax.random.split(key, num=3)

        # Dummy inputs with a leading batch axis of size 1.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))
        dummy_action = action_space.sample(key)[None, ...]

        critic_params = self.critic_network.init(critic_key, dummy_obs, dummy_action)
        actor_params = self.actor_network.init(actor_key, dummy_obs)

        params_state = SACNetworkParams(
            critic_params=critic_params,
            # The target critic starts out identical to the online critic.
            target_critic_params=critic_params,
            actor_params=actor_params,
            log_alpha=jnp.log(jnp.float32(self.init_alpha)),
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        # Target entropy is minus the number of action elements.
        target_entropy = -jnp.prod(jnp.array(action_space.shape, dtype=jnp.float32))

        return AgentState(
            params=params_state,
            obs_preprocessor_state=obs_preprocessor_state,
            extra_state=PyTreeDict(target_entropy=target_entropy),  # the constant
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample stochastic actions for ``sample_batch.obs``.

        Returns:
            The sampled actions and an empty ``PyTreeDict`` of extra info.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions = self._policy_dist(agent_state, obs).sample(seed=key)
        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the distribution mode for ``sample_batch.obs``.

        ``key`` is accepted for interface parity and is not used.

        Returns:
            The mode actions and an empty ``PyTreeDict`` of extra info.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions = self._policy_dist(agent_state, obs).mode()
        return actions, PyTreeDict()

    def alpha_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the entropy-temperature loss.

        Returns:
            A ``PyTreeDict`` with ``alpha_loss``, ``log_alpha`` and ``alpha``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        actions_dist = self._policy_dist(agent_state, obs)
        actions = actions_dist.sample(seed=key)
        actions_logp = actions_dist.log_prob(actions)

        target_entropy = agent_state.extra_state.target_entropy
        # official impl:
        alpha = jnp.exp(agent_state.params.log_alpha)
        # The policy term is treated as a constant: only alpha receives gradient.
        alpha_loss = jnp.mean(
            -alpha * jax.lax.stop_gradient(actions_logp + target_entropy)
        )

        # another impl: see stable-baselines3/issues/36
        # alpha_loss = (- agent_state.params.log_alpha *
        #               jax.lax.stop_gradient(actions_logp + target_entropy)).mean()

        return PyTreeDict(
            alpha_loss=alpha_loss, log_alpha=agent_state.params.log_alpha, alpha=alpha
        )

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the policy loss ``mean(alpha * logp - min_i Q_i)``.

        Returns:
            A ``PyTreeDict`` with ``actor_loss`` and the mean policy ``entropy``.
        """
        actor_key, entropy_key = jax.random.split(key, 2)
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        alpha = jnp.exp(agent_state.params.log_alpha)

        actions_dist = self._policy_dist(agent_state, obs)
        actions = actions_dist.sample(seed=actor_key)
        actions_logp = actions_dist.log_prob(actions)

        # Stacked Q-estimates, one per critic on the last axis.
        qs = self.critic_network.apply(agent_state.params.critic_params, obs, actions)
        qs_min = jnp.min(qs, axis=-1)
        actor_loss = jnp.mean(alpha * actions_logp - qs_min)
        entropy = actions_dist.entropy(seed=entropy_key).mean()

        return PyTreeDict(actor_loss=actor_loss, entropy=entropy)

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the soft Bellman error summed over critics.

        The TD target is broadcast to the critic output's own shape, so the
        loss works for any number of stacked critics rather than exactly two.

        Returns:
            A ``PyTreeDict`` with ``critic_loss``.
        """
        env_extras = sample_batch.extras.env_extras
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        # ``ori_obs`` is used as the next observation.
        next_obs = self._preprocess_obs(agent_state, env_extras.ori_obs)

        # No bootstrapping on transitions flagged as terminated.
        discounts = self.discount * (1 - env_extras.termination)

        alpha = jnp.exp(agent_state.params.log_alpha)

        # [B, n_critics]
        qs = self.critic_network.apply(
            agent_state.params.critic_params, obs, sample_batch.actions
        )

        next_actions_dist = self._policy_dist(agent_state, next_obs)
        next_actions = next_actions_dist.sample(seed=key)
        next_actions_logp = next_actions_dist.log_prob(next_actions)
        # [B, n_critics]
        next_qs = self.critic_network.apply(
            agent_state.params.target_critic_params, next_obs, next_actions
        )
        qs_target = sample_batch.rewards + discounts * (
            jnp.min(next_qs, axis=-1) - alpha * next_actions_logp
        )
        # Every critic regresses onto the same target.
        qs_target = jnp.broadcast_to(qs_target[..., None], qs.shape)

        q_loss = optax.squared_error(qs, qs_target).sum(-1).mean()
        return PyTreeDict(critic_loss=q_loss)
216
217
[docs]
class SACDiscreteAgent(Agent):
    """Soft Actor-Critic agent for discrete action spaces.

    The actor network outputs one logit per action. The critic network takes
    only the observation and returns Q-values shaped ``[B, n_critics, n]``
    (the losses take the minimum over ``axis=-2`` and sum over ``axis=-1``).

    Attributes:
        critic_network: Stacked discrete Q-network module.
        actor_network: Policy network module producing action logits.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``;
            ``None`` disables observation normalization.
        init_alpha: Initial entropy temperature; stored as ``log(init_alpha)``.
        discount: Discount factor used in the TD target.
        target_entropy_ratio: Target entropy as a fraction of ``log(n)``.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    init_alpha: float = 1.0
    discount: float = 0.99
    target_entropy_ratio: float = 0.98

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor has been configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs):
        """Apply the observation preprocessor when one is configured."""
        if self.normalize_obs:
            return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)
        return obs

    def _policy_logits(self, agent_state: AgentState, obs):
        """Run the actor network on preprocessed ``obs`` and return its logits."""
        return self.actor_network.apply(agent_state.params.actor_params, obs)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; sampled once for a dummy input.
            action_space: Action space; its ``n`` sets the target entropy.
            key: PRNG key for network initialization and dummy sampling.

        Returns:
            An ``AgentState`` with network parameters, the optional observation
            statistics and ``extra_state.target_entropy``.
        """
        key, critic_key, actor_key = jax.random.split(key, num=3)

        # Dummy observation with a leading batch axis of size 1.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))

        critic_params = self.critic_network.init(critic_key, dummy_obs)
        actor_params = self.actor_network.init(actor_key, dummy_obs)

        params_state = SACNetworkParams(
            critic_params=critic_params,
            # The target critic starts out identical to the online critic.
            target_critic_params=critic_params,
            actor_params=actor_params,
            log_alpha=jnp.log(jnp.float32(self.init_alpha)),
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        # log(n) is the entropy of the uniform policy over n actions.
        target_entropy = self.target_entropy_ratio * jnp.log(
            jnp.float32(action_space.n)
        )

        return AgentState(
            params=params_state,
            obs_preprocessor_state=obs_preprocessor_state,
            extra_state=PyTreeDict(target_entropy=target_entropy),  # the constant
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions from the categorical policy for ``sample_batch.obs``.

        Returns:
            The sampled actions and an empty ``PyTreeDict`` of extra info.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions_dist = get_categorical_dist(self._policy_logits(agent_state, obs))
        return actions_dist.sample(seed=key), PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the distribution mode for ``sample_batch.obs``.

        ``key`` is accepted for interface parity and is not used.

        Returns:
            The mode actions and an empty ``PyTreeDict`` of extra info.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions_dist = get_categorical_dist(self._policy_logits(agent_state, obs))
        return actions_dist.mode(), PyTreeDict()

    def alpha_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the entropy-temperature loss from the exact policy entropy.

        ``key`` is not used: the categorical entropy needs no sampling.

        Returns:
            A ``PyTreeDict`` with ``alpha_loss``, ``log_alpha`` and ``alpha``
            (the same keys as ``SACAgent.alpha_loss``).
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        actions_dist = get_categorical_dist(self._policy_logits(agent_state, obs))
        entropy = actions_dist.entropy()

        target_entropy = agent_state.extra_state.target_entropy
        # official impl:
        alpha = jnp.exp(agent_state.params.log_alpha)
        # The entropy gap is treated as a constant: only alpha receives gradient.
        alpha_loss = -jnp.mean(alpha * jax.lax.stop_gradient(target_entropy - entropy))

        return PyTreeDict(
            alpha_loss=alpha_loss, log_alpha=agent_state.params.log_alpha, alpha=alpha
        )

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the policy loss ``-mean(alpha * H + E_pi[min_i Q_i])``.

        ``key`` is not used: the expectation over actions is computed exactly.

        Returns:
            A ``PyTreeDict`` with ``actor_loss`` and the mean policy ``entropy``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        alpha = jnp.exp(agent_state.params.log_alpha)

        raw_actions = self._policy_logits(agent_state, obs)
        entropy = get_categorical_dist(raw_actions).entropy()
        actions_prob = nn.softmax(raw_actions)

        # [B, n_critics, n]
        qs = self.critic_network.apply(agent_state.params.critic_params, obs)
        qs_min = jnp.min(qs, axis=-2)
        # Expected Q under the current policy.
        qs_estimate = jnp.sum(qs_min * actions_prob, axis=-1)
        actor_loss = -jnp.mean(alpha * entropy + qs_estimate)

        return PyTreeDict(actor_loss=actor_loss, entropy=entropy.mean())

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the soft Bellman error summed over critics.

        The TD target is broadcast to the selected Q-values' own shape, so the
        loss works for any number of stacked critics rather than exactly two.
        ``key`` is not used.

        Returns:
            A ``PyTreeDict`` with ``critic_loss`` and the mean ``q_value`` of the
            taken actions.
        """
        env_extras = sample_batch.extras.env_extras
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        # ``ori_obs`` is used as the next observation.
        next_obs = self._preprocess_obs(agent_state, env_extras.ori_obs)

        # No bootstrapping on transitions flagged as terminated.
        discounts = self.discount * (1 - env_extras.termination)

        alpha = jnp.exp(agent_state.params.log_alpha)

        # [B, n_critics, n] -> [B, n_critics]: keep only the Q of the taken action.
        qs = self.critic_network.apply(agent_state.params.critic_params, obs)
        qs = jnp.take_along_axis(
            qs,
            sample_batch.actions.reshape(-1, 1, 1),
            axis=-1,
        ).squeeze(-1)

        next_raw_actions = self._policy_logits(agent_state, next_obs)
        next_actions_prob = nn.softmax(next_raw_actions)
        next_actions_logp = nn.log_softmax(next_raw_actions)
        # [B, n_critics, n]
        next_qs = self.critic_network.apply(
            agent_state.params.target_critic_params, next_obs
        )
        next_qs_min = jnp.min(next_qs, axis=-2)  # [B, n]
        next_qs_estimate = jnp.sum(
            next_actions_prob * (next_qs_min - alpha * next_actions_logp), axis=-1
        )  # [B]

        qs_target = sample_batch.rewards + discounts * next_qs_estimate
        # Every critic regresses onto the same target.
        qs_target = jnp.broadcast_to(qs_target[..., None], qs.shape)

        q_loss = optax.squared_error(qs, qs_target).sum(-1).mean()
        return PyTreeDict(critic_loss=q_loss, q_value=qs.mean())
379
380
[docs]
def make_mlp_sac_agent(
    action_space: Space,
    num_critics: int = 2,
    critic_hidden_layer_sizes: tuple[int] = (256, 256),
    actor_hidden_layer_sizes: tuple[int] = (256, 256),
    init_alpha: float = 1.0,
    discount: float = 0.99,
    target_entropy_ratio: float = 0.98,
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
):
    """Build an MLP-based SAC agent matching the given action space.

    Args:
        action_space: ``Box`` yields a ``SACAgent``; ``Discrete`` yields a
            ``SACDiscreteAgent``.
        num_critics: Number of stacked critics for the ``Box`` case.
        critic_hidden_layer_sizes: Hidden layer widths of the critic MLP.
        actor_hidden_layer_sizes: Hidden layer widths of the actor MLP.
        init_alpha: Initial entropy temperature.
        discount: Discount factor.
        target_entropy_ratio: Only forwarded to ``SACDiscreteAgent``.
        normalize_obs: If true, use ``running_statistics.normalize`` as the
            observation preprocessor.
        policy_obs_key: Forwarded as ``obs_key`` to the actor network.
        value_obs_key: Forwarded as ``obs_key`` to the critic network.

    Returns:
        A ``SACAgent`` or ``SACDiscreteAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor
            ``Discrete``.
    """
    is_continuous = isinstance(action_space, Box)
    if not is_continuous and not isinstance(action_space, Discrete):
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    # Box: two outputs per action dimension (mean+std); Discrete: one per action.
    policy_output_size = action_space.shape[0] * 2 if is_continuous else action_space.n

    shared_agent_kwargs = dict(
        actor_network=make_policy_network(
            action_size=policy_output_size,
            hidden_layer_sizes=actor_hidden_layer_sizes,
            obs_key=policy_obs_key,
        ),
        obs_preprocessor=running_statistics.normalize if normalize_obs else None,
        init_alpha=init_alpha,
        discount=discount,
    )

    if is_continuous:
        return SACAgent(
            critic_network=make_q_network(
                n_stack=num_critics,
                hidden_layer_sizes=critic_hidden_layer_sizes,
                obs_key=value_obs_key,
            ),
            **shared_agent_kwargs,
        )

    # NOTE(review): the discrete critic is always built with two heads;
    # ``num_critics`` is not consulted on this path.
    return SACDiscreteAgent(
        critic_network=make_discrete_q_network(
            action_size=policy_output_size,
            n_stack=2,
            hidden_layer_sizes=critic_hidden_layer_sizes,
            obs_key=value_obs_key,
        ),
        target_entropy_ratio=target_entropy_ratio,
        **shared_agent_kwargs,
    )
442
443
[docs]
class SACWorkflow(OffPolicyWorkflowTemplate):
    """Off-policy workflow that trains a SAC agent from a replay buffer."""

    @classmethod
    def name(cls):
        """Return the workflow's display name."""
        return "SAC"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Assemble envs, agent, optimizer, buffer and evaluator from ``config``."""
        max_episode_steps = config.env.max_episode_steps
        network_cfg = config.agent_network

        # Training env records ``ori_obs``, which the critic loss reads.
        env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        agent = make_mlp_sac_agent(
            action_space=env.action_space,
            num_critics=network_cfg.num_critics,
            critic_hidden_layer_sizes=network_cfg.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=network_cfg.actor_hidden_layer_sizes,
            init_alpha=config.alpha,
            discount=config.discount,
            normalize_obs=config.normalize_obs,
            target_entropy_ratio=config.target_entropy_ratio,
            policy_obs_key=network_cfg.policy_obs_key,
            value_obs_key=network_cfg.value_obs_key,
        )

        # TODO: use different lr for critic and actor
        clip_norm = config.optimizer.grad_clip_norm
        optimizer = optax.adam(config.optimizer.lr)
        if clip_norm is not None and clip_norm > 0:
            # Clip gradients by global norm before the Adam update.
            optimizer = optax.chain(optax.clip_by_global_norm(clip_norm), optimizer)

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=max(
                config.batch_size, config.learning_start_timesteps
            ),
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, replay_buffer, config)

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialise the agent state and one optimizer state per parameter group.

        The ``alpha`` optimizer state is only created when
        ``config.adaptive_alpha`` is enabled.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        params = agent_state.params

        opt_state = PyTreeDict(
            actor=self.optimizer.init(params.actor_params),
            critic=self.optimizer.init(params.critic_params),
        )
        if self.config.adaptive_alpha:
            opt_state = opt_state.replace(alpha=self.optimizer.init(params.log_alpha))

        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one iteration: rollout, buffer insertion, then gradient updates.

        Each of the ``config.num_updates_per_iter`` update rounds does
        ``actor_update_interval - 1`` critic-only updates (each on a fresh
        batch), then one critic, one actor and (optionally) one alpha update on
        a shared batch, followed by a soft update of the target critic.

        Returns:
            The training metrics and the updated workflow state.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Keep the done flags before the trajectory is cleaned and flattened.
        trajectory_dones = trajectory.dones
        trajectory = tree_stop_gradient(
            flatten_rollout_trajectory(clean_trajectory(trajectory))
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            new_obs_stats = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                dp_axis_name=self.dp_axis_name,
            )
            agent_state = agent_state.replace(obs_preprocessor_state=new_obs_stats)

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def _with_params(agent_state, **changes):
            # Return ``agent_state`` with some entries of ``params`` swapped.
            return agent_state.replace(params=agent_state.params.replace(**changes))

        def _build_update_fn(loss_fn, attach_fn, detach_fn):
            # All three updates share the optimizer and the reduction axis.
            return agent_gradient_update(
                loss_fn,
                self.optimizer,
                dp_axis_name=self.dp_axis_name,
                has_aux=True,
                attach_fn=attach_fn,
                detach_fn=detach_fn,
            )

        def critic_loss_fn(agent_state, sample_batch, key):
            losses = self.agent.critic_loss(agent_state, sample_batch, key)
            return losses.critic_loss, losses

        def actor_loss_fn(agent_state, sample_batch, key):
            losses = self.agent.actor_loss(agent_state, sample_batch, key)
            return losses.actor_loss, losses

        def alpha_loss_fn(agent_state, sample_batch, key):
            losses = self.agent.alpha_loss(agent_state, sample_batch, key)
            return losses.alpha_loss, losses

        critic_update_fn = _build_update_fn(
            critic_loss_fn,
            attach_fn=lambda agent_state, critic_params: _with_params(
                agent_state, critic_params=critic_params
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )
        actor_update_fn = _build_update_fn(
            actor_loss_fn,
            attach_fn=lambda agent_state, actor_params: _with_params(
                agent_state, actor_params=actor_params
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )
        alpha_update_fn = _build_update_fn(
            alpha_loss_fn,
            attach_fn=lambda agent_state, log_alpha: _with_params(
                agent_state, log_alpha=log_alpha
            ),
            detach_fn=lambda agent_state: agent_state.params.log_alpha,
        )

        num_critic_only_updates = self.config.actor_update_interval - 1

        def _sample_and_update_critic_fn(carry, unused_t):
            # One critic-only update on a freshly sampled batch.
            key, agent_state, critic_opt_state = carry
            key, rb_key, critic_key = jax.random.split(key, num=3)

            # it's safe to use read-only replay_buffer_state here.
            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)
            _, agent_state, critic_opt_state = critic_update_fn(
                critic_opt_state, agent_state, sample_batch, critic_key
            )
            return (key, agent_state, critic_opt_state), None

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry
            critic_opt_state, actor_opt_state = opt_state.critic, opt_state.actor

            key, critic_key, actor_key, alpha_key, rb_key = jax.random.split(key, num=5)

            if num_critic_only_updates > 0:
                key, critic_multiple_update_key = jax.random.split(key)
                (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                    _sample_and_update_critic_fn,
                    (critic_multiple_update_key, agent_state, critic_opt_state),
                    (),
                    length=num_critic_only_updates,
                )

            # The remaining updates of this round all use the same batch.
            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(critic_opt_state, agent_state, sample_batch, critic_key)
            )
            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(actor_opt_state, agent_state, sample_batch, actor_key)
            )
            opt_state = opt_state.replace(
                actor=actor_opt_state, critic=critic_opt_state
            )

            losses = [critic_loss, actor_loss]
            loss_dicts = [critic_loss_dict, actor_loss_dict]

            if self.config.adaptive_alpha:
                # we follow the update order of the official implementation:
                # critic -> actor -> alpha
                (alpha_loss, alpha_loss_dict), agent_state, alpha_opt_state = (
                    alpha_update_fn(
                        opt_state.alpha, agent_state, sample_batch, alpha_key
                    )
                )
                opt_state = opt_state.replace(alpha=alpha_opt_state)

                # Report the temperature as it is after this update.
                new_log_alpha = agent_state.params.log_alpha
                alpha_loss_dict = alpha_loss_dict.replace(
                    log_alpha=new_log_alpha,
                    alpha=jnp.exp(new_log_alpha),
                )
                losses.append(alpha_loss)
                loss_dicts.append(alpha_loss_dict)

            agent_state = _with_params(
                agent_state,
                target_critic_params=soft_target_update(
                    agent_state.params.target_critic_params,
                    agent_state.params.critic_params,
                    self.config.tau,
                ),
            )

            # Scalar losses first, then the loss dicts, in update order.
            return (key, agent_state, opt_state), (*losses, *loss_dicts)

        (_, agent_state, opt_state), update_outputs = scan_and_mean(
            _sample_and_update_fn,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )

        if self.config.adaptive_alpha:
            (
                critic_loss,
                actor_loss,
                alpha_loss,
                critic_loss_dict,
                actor_loss_dict,
                alpha_loss_dict,
            ) = update_outputs
            train_metrics = SACTrainMetric(
                actor_loss=actor_loss,
                critic_loss=critic_loss,
                alpha_loss=alpha_loss,
                raw_loss_dict=PyTreeDict(
                    {**critic_loss_dict, **actor_loss_dict, **alpha_loss_dict}
                ),
            )
        else:
            critic_loss, actor_loss, critic_loss_dict, actor_loss_dict = update_outputs
            train_metrics = SACTrainMetric(
                actor_loss=actor_loss,
                critic_loss=critic_loss,
                raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
            )
        train_metrics = train_metrics.all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        previous_metrics = state.metrics
        workflow_metrics = previous_metrics.replace(
            sampled_timesteps=previous_metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=previous_metrics.sampled_episodes + sampled_episodes,
            iterations=previous_metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        new_state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )
        return train_metrics, new_state