1import logging
2from typing import Any, Sequence
3import math
4
5import chex
6import flax.linen as nn
7import jax
8import jax.numpy as jnp
9import jax.tree_util as jtu
10import optax
11from omegaconf import DictConfig
12
13from evorl.distributed import psum, pmean
14from evorl.distributed.gradients import gradient_update
15from evorl.envs import AutoresetMode, Box, create_env, Space
16from evorl.evaluators import Evaluator
17from evorl.metrics import MetricBase, metric_field
18from evorl.sample_batch import SampleBatch
19from evorl.types import (
20 Action,
21 LossDict,
22 Params,
23 PolicyExtraInfo,
24 PyTreeData,
25 PyTreeDict,
26 State,
27 pytree_field,
28)
29from evorl.utils import running_statistics
30from evorl.utils.jax_utils import (
31 scan_and_mean,
32 tree_stop_gradient,
33 tree_get,
34 right_shift_with_padding,
35)
36from evorl.evaluators import EpisodeCollector
37from evorl.agent import Agent, AgentState
38from evorl.replay_buffers import LAPReplayBuffer
39from evorl.recorders import add_prefix
40from evorl.networks.linear import MLP
41
42from .offpolicy_utils import OffPolicyWorkflowTemplate, skip_replay_buffer_state
43
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
45
46
def avg_l1_norm(x: jax.Array, eps: float = 1e-8) -> jax.Array:
    """Average L1 Norm (AvgL1Norm) used in TD7.

    Divides ``x`` by the mean absolute value taken over its last axis. The
    denominator is floored at ``eps`` so that an all-zero vector yields zeros
    instead of NaN/Inf.

    Args:
        x: Input array; normalization is applied over the last axis.
        eps: Lower bound on the denominator.

    Returns:
        Array with the same shape as ``x``.
    """
    mean_abs = jnp.mean(jnp.abs(x), axis=-1, keepdims=True)
    # Equivalent to jnp.clip(mean_abs, a_min=eps); the ``a_min`` keyword of
    # jnp.clip is deprecated in newer JAX releases, jnp.maximum is not.
    return x / jnp.maximum(mean_abs, eps)
51
52
class TD7Encoder(nn.Module):
    """State / state-action encoder of TD7.

    ``zs`` maps observations to a state embedding rescaled by
    ``avg_l1_norm``. ``zsa`` maps a state embedding concatenated with an
    action to a state-action embedding. Both sub-networks are ELU MLPs.
    """

    z_s_dim: int = 256
    z_sa_dim: int = 256
    f_layer_sizes: Sequence[int] = (256, 256)
    g_layer_sizes: Sequence[int] = (256, 256)

    def setup(self):
        # f: obs -> z_s, with output width z_s_dim.
        self.zs_mlp = MLP(
            layer_sizes=(*self.f_layer_sizes, self.z_s_dim),
            activation=nn.elu,
            name="zs_mlp",
        )
        # g: concat(z_s, action) -> z_sa, with output width z_sa_dim.
        self.zsa_mlp = MLP(
            layer_sizes=(*self.g_layer_sizes, self.z_sa_dim),
            activation=nn.elu,
            name="zsa_mlp",
        )

    def zs(self, obs: jax.Array) -> jax.Array:
        """Embed observations and apply AvgL1Norm to the result."""
        return avg_l1_norm(self.zs_mlp(obs))

    def zsa(self, z_s: jax.Array, action: jax.Array) -> jax.Array:
        """Embed a (state embedding, action) pair; no normalization applied."""
        joint = jnp.concatenate([z_s, action], axis=-1)
        return self.zsa_mlp(joint)

    def __call__(
        self, obs: jax.Array, action: jax.Array
    ) -> tuple[jax.Array, jax.Array]:
        """Run both encoders; mainly exists so ``init`` touches every submodule."""
        state_embedding = self.zs(obs)
        return state_embedding, self.zsa(state_embedding, action)
86
87
class TD7Actor(nn.Module):
    """Deterministic TD7 policy conditioned on the observation and ``z_s``.

    The observation is projected by a Dense layer and AvgL1Norm-rescaled.
    It is then concatenated with the state embedding and passed through a
    ReLU MLP. Outputs are squashed to [-1, 1] with tanh.
    """

    action_size: int
    z_s_dim: int = 256
    state_emb_dim: int = 256
    hidden_layer_sizes: Sequence[int] = (256, 256)

    @nn.compact
    def __call__(self, obs: jax.Array, z_s: jax.Array) -> jax.Array:
        state_emb = avg_l1_norm(nn.Dense(self.state_emb_dim, name="l0")(obs))
        features = jnp.concatenate([state_emb, z_s], axis=-1)

        policy_head = MLP(
            layer_sizes=(*self.hidden_layer_sizes, self.action_size),
            activation=nn.relu,
            name="actor_mlp",
        )
        return nn.tanh(policy_head(features))
107
108
class TD7Critic(nn.Module):
    """Twin-Q critic of TD7.

    Each head embeds concat(obs, action) with a Dense layer plus AvgL1Norm.
    It appends ``z_sa`` and ``z_s`` and maps the result to a scalar with an
    ELU MLP. Returns the two head values stacked on the last axis.
    """

    z_s_dim: int = 256
    z_sa_dim: int = 256
    state_action_emb_dim: int = 256
    hidden_layer_sizes: Sequence[int] = (256, 256)

    @nn.compact
    def __call__(
        self, obs: jax.Array, action: jax.Array, z_sa: jax.Array, z_s: jax.Array
    ) -> jax.Array:
        obs_action = jnp.concatenate([obs, action], axis=-1)

        head_outputs = []
        # Submodule names ("q1_0", "q1_mlp", "q2_0", "q2_mlp") are kept
        # explicit so the parameter tree layout is stable.
        for prefix in ("q1", "q2"):
            hidden = nn.Dense(self.state_action_emb_dim, name=f"{prefix}_0")(obs_action)
            hidden = avg_l1_norm(hidden)
            hidden = jnp.concatenate([hidden, z_sa, z_s], axis=-1)
            q_value = MLP(
                layer_sizes=(*self.hidden_layer_sizes, 1),
                activation=nn.elu,
                name=f"{prefix}_mlp",
            )(hidden)
            head_outputs.append(q_value)

        return jnp.concatenate(head_outputs, axis=-1)
142
143
class TD7TrainMetric(MetricBase):
    """Training metrics produced by one TD7 workflow step."""

    # Scalar losses of the critic, actor and encoder updates.
    critic_loss: chex.Array
    actor_loss: chex.Array
    encoder_loss: chex.Array
    # Auxiliary loss-function outputs; reduced across devices with ``pmean``.
    raw_loss_dict: LossDict = metric_field(reduce_fn=pmean, default_factory=PyTreeDict)
149
150
class TD7NetworkParams(PyTreeData):
    """Every parameter set held by a TD7 agent.

    NOTE(review): PyTreeData is presumably a dataclass-style pytree, so the
    field order likely fixes the flattened layout (and checkpoint format);
    keep the order stable.
    """

    # Online networks, updated by gradient descent.
    actor_params: Params
    critic_params: Params
    encoder_params: Params
    # Target networks used when building the critic's TD target.
    target_actor_params: Params
    target_critic_params: Params
    # Frozen encoder copies: one feeds the online actor/critic, the other
    # feeds the target networks.
    fixed_encoder_params: Params
    fixed_encoder_target_params: Params
    # Snapshot used by the evaluation policy.
    checkpoint_actor_params: Params
    checkpoint_encoder_params: Params
161
162
class TD7Agent(Agent):
    """The Agent for TD7.

    Parameter roles (see ``TD7NetworkParams``):

    * ``encoder_params`` are trained by :meth:`encoder_loss`.
    * ``fixed_encoder_params`` embed inputs for the online actor and critic
      (used by :meth:`compute_actions`, :meth:`critic_loss`, :meth:`actor_loss`).
    * ``fixed_encoder_target_params`` embed inputs for the target networks.
    * ``checkpoint_*_params`` are only used by :meth:`evaluate_actions`.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    encoder_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    discount: float = 0.99
    exploration_epsilon: float = 0.1
    policy_noise: float = 0.2
    clip_policy_noise: float = 0.5
    min_priority: float = 1.0

    @property
    def normalize_obs(self):
        """True when an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def _maybe_normalize(self, obs, agent_state: AgentState):
        """Run ``obs`` through the preprocessor if one is configured."""
        if not self.normalize_obs:
            return obs
        return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create network params, normalizer state and TD7 bookkeeping.

        Every target / fixed / checkpoint copy starts out equal to the freshly
        initialized online parameters.
        """
        sample_key, critic_key, actor_key, encoder_key = jax.random.split(key, num=4)

        # Dummy inputs with a leading batch axis of size 1.
        dummy_obs = jtu.tree_map(
            lambda leaf: leaf[None, ...], obs_space.sample(sample_key)
        )
        dummy_action = action_space.sample(sample_key)[None, ...]

        encoder_params = self.encoder_network.init(encoder_key, dummy_obs, dummy_action)
        # Actor and critic consume embeddings, so run the new encoder once to
        # obtain correctly shaped z_s / z_sa.
        dummy_z_s, dummy_z_sa = self.encoder_network.apply(
            encoder_params, dummy_obs, dummy_action
        )
        critic_params = self.critic_network.init(
            critic_key, dummy_obs, dummy_action, dummy_z_sa, dummy_z_s
        )
        actor_params = self.actor_network.init(actor_key, dummy_obs, dummy_z_s)

        params_state = TD7NetworkParams(
            actor_params=actor_params,
            critic_params=critic_params,
            encoder_params=encoder_params,
            target_actor_params=actor_params,
            target_critic_params=critic_params,
            fixed_encoder_params=encoder_params,
            fixed_encoder_target_params=encoder_params,
            checkpoint_actor_params=actor_params,
            checkpoint_encoder_params=encoder_params,
        )

        obs_preprocessor_state = (
            running_statistics.init_state(tree_get(dummy_obs, 0))
            if self.normalize_obs
            else None
        )

        # Running Q extrema (max_q/min_q), the clipping range applied to
        # target Q values (max_target/min_target), and the best checkpoint
        # performance seen so far.
        f32_info = jnp.finfo(jnp.float32)
        extra_state = PyTreeDict(
            max_q=jnp.array(f32_info.min, dtype=jnp.float32),
            min_q=jnp.array(f32_info.max, dtype=jnp.float32),
            max_target=jnp.array(0.0, dtype=jnp.float32),
            min_target=jnp.array(0.0, dtype=jnp.float32),
            best_perf=jnp.array(f32_info.min, dtype=jnp.float32),
        )

        return AgentState(
            params=params_state,
            obs_preprocessor_state=obs_preprocessor_state,
            extra_state=extra_state,
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Exploration policy: online actor + fixed encoder + Gaussian noise, clipped to [-1, 1]."""
        params = agent_state.params
        obs = self._maybe_normalize(sample_batch.obs, agent_state)

        z_s = self.encoder_network.apply(params.fixed_encoder_params, obs, method="zs")
        greedy_actions = self.actor_network.apply(params.actor_params, obs, z_s)

        exploration_noise = jax.random.normal(key, greedy_actions.shape) * self.exploration_epsilon
        return jnp.clip(greedy_actions + exploration_noise, -1.0, 1.0), PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Deterministic evaluation policy built from the checkpoint actor and encoder."""
        params = agent_state.params
        obs = self._maybe_normalize(sample_batch.obs, agent_state)

        z_s = self.encoder_network.apply(
            params.checkpoint_encoder_params, obs, method="zs"
        )
        actions = self.actor_network.apply(params.checkpoint_actor_params, obs, z_s)
        return actions, PyTreeDict()

    def encoder_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """MSE between zsa(zs(obs), action) and the gradient-stopped zs(next_obs)."""
        enc_params = agent_state.params.encoder_params
        obs = self._maybe_normalize(sample_batch.obs, agent_state)
        next_obs = self._maybe_normalize(
            sample_batch.extras.env_extras.ori_obs, agent_state
        )

        target_z_s = jax.lax.stop_gradient(
            self.encoder_network.apply(enc_params, next_obs, method="zs")
        )
        z_s = self.encoder_network.apply(enc_params, obs, method="zs")
        predicted_z_sa = self.encoder_network.apply(
            enc_params, z_s, sample_batch.actions, method="zsa"
        )

        return PyTreeDict(
            encoder_loss=optax.squared_error(predicted_z_sa, target_z_s).mean()
        )

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """LAP-Huber TD loss for both Q heads plus priorities and target extrema.

        The target is ``r + discount * (1 - termination) * clip(min_i Q'_i,
        min_target, max_target)``, evaluated with the target actor/critic and
        the fixed target encoder, using clipped Gaussian smoothing on the
        target action.
        """
        params = agent_state.params
        extra = agent_state.extra_state
        actions = sample_batch.actions
        obs = self._maybe_normalize(sample_batch.obs, agent_state)
        next_obs = self._maybe_normalize(
            sample_batch.extras.env_extras.ori_obs, agent_state
        )

        # ---- TD target -------------------------------------------------
        target_z_s = self.encoder_network.apply(
            params.fixed_encoder_target_params, next_obs, method="zs"
        )
        smoothing_noise = jnp.clip(
            jax.random.normal(key, actions.shape) * self.policy_noise,
            -self.clip_policy_noise,
            self.clip_policy_noise,
        )
        next_actions = jnp.clip(
            self.actor_network.apply(params.target_actor_params, next_obs, target_z_s)
            + smoothing_noise,
            -1.0,
            1.0,
        )
        target_z_sa = self.encoder_network.apply(
            params.fixed_encoder_target_params,
            target_z_s,
            next_actions,
            method="zsa",
        )
        next_q_min = self.critic_network.apply(
            params.target_critic_params,
            next_obs,
            next_actions,
            target_z_sa,
            target_z_s,
        ).min(-1)

        discounts = self.discount * (1 - sample_batch.extras.env_extras.termination)
        clipped_next_q = jnp.clip(next_q_min, extra.min_target, extra.max_target)
        q_target = sample_batch.rewards + discounts * clipped_next_q
        # One copy per Q head.
        q_target = jax.lax.stop_gradient(
            jnp.broadcast_to(q_target[..., None], (*q_target.shape, 2))
        )

        # ---- Current Q ---------------------------------------------------
        z_s = self.encoder_network.apply(params.fixed_encoder_params, obs, method="zs")
        z_sa = self.encoder_network.apply(
            params.fixed_encoder_params, z_s, actions, method="zsa"
        )
        qs = self.critic_network.apply(params.critic_params, obs, actions, z_sa, z_s)

        td_error = jnp.abs(qs - q_target)

        # LAP Huber: quadratic below min_priority, linear above; summed over
        # heads, averaged over the batch.
        threshold = self.min_priority
        elementwise = jnp.where(
            td_error < threshold,
            0.5 * jnp.square(td_error),
            threshold * td_error,
        )
        loss = elementwise.sum(-1).mean()

        return PyTreeDict(
            critic_loss=loss,
            q_value=qs.mean(),
            priority=jax.lax.stop_gradient(
                jnp.maximum(td_error.max(axis=-1), self.min_priority)
            ),
            batch_q_max=jax.lax.stop_gradient(q_target[..., 0].max()),
            batch_q_min=jax.lax.stop_gradient(q_target[..., 0].min()),
        )

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Negative mean Q (over both heads and the batch) of the online actor's actions."""
        params = agent_state.params
        obs = self._maybe_normalize(sample_batch.obs, agent_state)

        z_s = self.encoder_network.apply(params.fixed_encoder_params, obs, method="zs")
        policy_actions = self.actor_network.apply(params.actor_params, obs, z_s)
        z_sa = self.encoder_network.apply(
            params.fixed_encoder_params, z_s, policy_actions, method="zsa"
        )
        qs = self.critic_network.apply(
            params.critic_params, obs, policy_actions, z_sa, z_s
        )

        return PyTreeDict(actor_loss=-jnp.mean(qs))
421
422
def make_td7_agent(
    action_space: Space,
    z_s_dim: int = 256,
    z_sa_dim: int = 256,
    f_layer_sizes: Sequence[int] = (256, 256),
    g_layer_sizes: Sequence[int] = (256, 256),
    state_emb_dim: int = 256,
    state_action_emb_dim: int = 256,
    critic_hidden_layer_sizes: Sequence[int] = (256, 256),
    actor_hidden_layer_sizes: Sequence[int] = (256, 256),
    discount: float = 0.99,
    exploration_epsilon: float = 0.1,
    policy_noise: float = 0.2,
    clip_policy_noise: float = 0.5,
    min_priority: float = 1.0,
    normalize_obs: bool = False,
):
    """Build a :class:`TD7Agent` together with its encoder, actor and critic.

    Args:
        action_space: Must be a ``Box``; its first dimension is the action size.
        z_s_dim / z_sa_dim: Widths of the state and state-action embeddings.
        f_layer_sizes / g_layer_sizes: Hidden sizes of the zs / zsa encoders.
        state_emb_dim: Width of the actor's observation projection.
        state_action_emb_dim: Width of each critic head's (obs, action) projection.
        critic_hidden_layer_sizes / actor_hidden_layer_sizes: MLP hidden sizes.
        discount, exploration_epsilon, policy_noise, clip_policy_noise,
            min_priority: Forwarded unchanged to ``TD7Agent``.
        normalize_obs: Use ``running_statistics.normalize`` on observations.

    Returns:
        The configured ``TD7Agent``.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    assert isinstance(action_space, Box), "Only continue action space is supported."

    encoder = TD7Encoder(
        z_s_dim=z_s_dim,
        z_sa_dim=z_sa_dim,
        f_layer_sizes=f_layer_sizes,
        g_layer_sizes=g_layer_sizes,
    )
    critic = TD7Critic(
        z_s_dim=z_s_dim,
        z_sa_dim=z_sa_dim,
        state_action_emb_dim=state_action_emb_dim,
        hidden_layer_sizes=critic_hidden_layer_sizes,
    )
    actor = TD7Actor(
        action_size=action_space.shape[0],
        z_s_dim=z_s_dim,
        state_emb_dim=state_emb_dim,
        hidden_layer_sizes=actor_hidden_layer_sizes,
    )

    return TD7Agent(
        encoder_network=encoder,
        critic_network=critic,
        actor_network=actor,
        obs_preprocessor=running_statistics.normalize if normalize_obs else None,
        discount=discount,
        exploration_epsilon=exploration_epsilon,
        policy_noise=policy_noise,
        clip_policy_noise=clip_policy_noise,
        min_priority=min_priority,
    )
479
480
class TD7Workflow(OffPolicyWorkflowTemplate):
    """Training workflow for :class:`TD7Agent`.

    Each :meth:`step` does the following:

    1. Rolls out ``config.rollout_episodes`` whole episodes with the
       exploration policy (autoreset disabled).
    2. Adds the unpadded transitions to a LAP replay buffer.
    3. Runs ``config.num_updates_per_iter`` encoder/critic/(delayed) actor
       updates with periodic hard target updates.
    4. Promotes the policy that generated the rollout to the evaluation
       checkpoint when its episode performance is at least the best so far.
    """

    @classmethod
    def name(cls):
        return "TD7"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create env, agent, optimizer, replay buffer, evaluator and collector."""
        assert config.rollout_episodes % config.num_envs == 0, (
            "rollout_episodes must be divisible by num_envs"
        )

        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_td7_agent(
            action_space=env.action_space,
            z_s_dim=config.agent_network.zs_dim,
            z_sa_dim=config.agent_network.zsa_dim,
            f_layer_sizes=config.agent_network.f_layer_sizes,
            g_layer_sizes=config.agent_network.g_layer_sizes,
            state_emb_dim=config.agent_network.state_emb_dim,
            state_action_emb_dim=config.agent_network.state_action_emb_dim,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            min_priority=config.min_priority,
            normalize_obs=config.normalize_obs,
        )

        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        replay_buffer = LAPReplayBuffer(
            capacity=config.replay_buffer_capacity,
            sample_batch_size=config.batch_size,
            alpha=config.lap_alpha,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        collector = EpisodeCollector(
            env=env,
            action_fn=agent.compute_actions,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        workflow = cls(
            env,
            agent,
            optimizer,
            evaluator,
            replay_buffer,
            config,
        )
        workflow.collector = collector
        return workflow

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent and one optimizer state per trained network."""
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)

        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
            encoder=self.optimizer.init(agent_state.params.encoder_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Collect episodes, train, and possibly update the evaluation checkpoint."""
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # Parameters that drive the rollout below. compute_actions acts with
        # actor_params + fixed_encoder_params, so these are exactly the params
        # whose returns end up in eval_metrics. They are the ones that must be
        # checkpointed, not the params produced by this step's training.
        rollout_params = state.agent_state.params

        # Collect episodes (rollout [T, episodes, ...])
        eval_metrics, trajectory = self.collector.rollout(
            state.agent_state, rollout_key, self.config.rollout_episodes
        )

        trajectory = trajectory.replace(next_obs=None)

        # Mask out padded steps based on `dones` array (since autoreset is OFF)
        mask = jnp.logical_not(right_shift_with_padding(trajectory.dones, 1))
        trajectory = trajectory.replace(dones=None)

        def _flatten_fn(x):
            # Merge the two leading axes into a single batch axis.
            return x.reshape(-1, *x.shape[2:])

        trajectory = jtu.tree_map(_flatten_fn, trajectory)
        mask = jtu.tree_map(_flatten_fn, mask)

        trajectory, mask = tree_stop_gradient((trajectory, mask))

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            # NOTE(review): the normalizer is updated with every flattened
            # step, including steps that `mask` marks as padding. Confirm
            # this is intended.
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory, mask=mask
        )

        # Gradient update wrappers (params are nested in TD7NetworkParams, so
        # each loss substitutes the trained subtree into a temporary state).
        def encoder_loss_fn(params, agent_state, sample_batch, key):
            temp_agent_state = agent_state.replace(
                params=agent_state.params.replace(encoder_params=params)
            )
            loss_dict = self.agent.encoder_loss(temp_agent_state, sample_batch, key)
            return loss_dict.encoder_loss, loss_dict

        def critic_loss_fn(params, agent_state, sample_batch, key):
            temp_agent_state = agent_state.replace(
                params=agent_state.params.replace(critic_params=params)
            )
            loss_dict = self.agent.critic_loss(temp_agent_state, sample_batch, key)
            return loss_dict.critic_loss, loss_dict

        def actor_loss_fn(params, agent_state, sample_batch, key):
            temp_agent_state = agent_state.replace(
                params=agent_state.params.replace(actor_params=params)
            )
            loss_dict = self.agent.actor_loss(temp_agent_state, sample_batch, key)
            return loss_dict.actor_loss, loss_dict

        encoder_update_fn = gradient_update(
            encoder_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
        )

        critic_update_fn = gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
        )

        actor_update_fn = gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
        )

        def _sample_and_update_fn(carry, t):
            key, agent_state, opt_state, replay_state, training_steps = carry

            enc_opt_state = opt_state.encoder
            critic_opt_state = opt_state.critic
            actor_opt_state = opt_state.actor

            key, enc_key, critic_key, actor_key, rb_key = jax.random.split(key, num=5)

            # Sample from Replay Buffer (yields batch + index tracking state in one pass)
            sample_batch, _weights, replay_state = self.replay_buffer.sample(
                replay_state, rb_key
            )

            # 1. Update Encoder
            (enc_loss, enc_loss_dict), enc_params, enc_opt_state = encoder_update_fn(
                enc_opt_state,
                agent_state.params.encoder_params,
                agent_state,
                sample_batch,
                enc_key,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(encoder_params=enc_params)
            )

            # 2. Update Critic
            (critic_loss, critic_loss_dict), critic_params, critic_opt_state = (
                critic_update_fn(
                    critic_opt_state,
                    agent_state.params.critic_params,
                    agent_state,
                    sample_batch,
                    critic_key,
                )
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            )

            # LAP Priority updates
            priority = critic_loss_dict.priority
            replay_state = self.replay_buffer.update_priority(replay_state, priority)

            # Running max/min of target Q values, used as the clipping range
            # at the next hard target update.
            new_max = jnp.maximum(
                agent_state.extra_state.max_q, critic_loss_dict.batch_q_max
            )
            new_min = jnp.minimum(
                agent_state.extra_state.min_q, critic_loss_dict.batch_q_min
            )
            agent_state = agent_state.replace(
                extra_state=agent_state.extra_state.replace(
                    max_q=new_max,
                    min_q=new_min,
                )
            )

            # 3. Update Actor (delayed: only every policy_freq training steps)
            def _update_actor(carry):
                agent_state, actor_opt_state = carry
                (actor_loss, actor_loss_dict), actor_params, actor_opt_state = (
                    actor_update_fn(
                        actor_opt_state,
                        agent_state.params.actor_params,
                        agent_state,
                        sample_batch,
                        actor_key,
                    )
                )
                agent_state = agent_state.replace(
                    params=agent_state.params.replace(actor_params=actor_params)
                )
                return agent_state, actor_opt_state, actor_loss, actor_loss_dict

            def _skip_actor(carry):
                agent_state, actor_opt_state = carry
                # Zero losses on skipped steps; rescaled after the scan.
                return (
                    agent_state,
                    actor_opt_state,
                    jnp.array(0.0),
                    PyTreeDict(actor_loss=jnp.array(0.0)),
                )

            agent_state, actor_opt_state, actor_loss, actor_loss_dict = jax.lax.cond(
                (training_steps + 1) % self.config.policy_freq == 0,
                _update_actor,
                _skip_actor,
                (agent_state, actor_opt_state),
            )

            # 4. Hard target updates
            def _hard_target_updates(carry):
                agent_state, replay_state = carry
                # All right-hand sides read the pre-update params, so the
                # target encoder receives the *old* fixed encoder.
                agent_state = agent_state.replace(
                    params=agent_state.params.replace(
                        target_actor_params=agent_state.params.actor_params,
                        target_critic_params=agent_state.params.critic_params,
                        fixed_encoder_target_params=agent_state.params.fixed_encoder_params,
                        fixed_encoder_params=agent_state.params.encoder_params,
                    ),
                    extra_state=agent_state.extra_state.replace(
                        max_target=agent_state.extra_state.max_q,
                        min_target=agent_state.extra_state.min_q,
                    ),
                )
                replay_state = self.replay_buffer.reset_max_priority(replay_state)
                return agent_state, replay_state

            def _skip_updates(carry):
                return carry

            agent_state, replay_state = jax.lax.cond(
                (training_steps + 1) % self.config.target_update_rate == 0,
                _hard_target_updates,
                _skip_updates,
                (agent_state, replay_state),
            )

            opt_state = opt_state.replace(
                encoder=enc_opt_state, actor=actor_opt_state, critic=critic_opt_state
            )

            return (
                (key, agent_state, opt_state, replay_state, training_steps + 1),
                (
                    enc_loss,
                    critic_loss,
                    actor_loss,
                    enc_loss_dict,
                    critic_loss_dict,
                    actor_loss_dict,
                ),
            )

        # Global training-step counter at the start of this iteration.
        global_steps = state.metrics.iterations * self.config.num_updates_per_iter

        # Loop index array for the scan (values unused by the body).
        iters = jnp.arange(self.config.num_updates_per_iter, dtype=jnp.int32)

        (
            (_, agent_state, opt_state, replay_buffer_state, _),
            (
                encoder_loss,
                critic_loss,
                actor_loss,
                enc_loss_dict,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (
                learn_key,
                agent_state,
                state.opt_state,
                replay_buffer_state,
                global_steps,
            ),
            iters,
            length=self.config.num_updates_per_iter,
        )

        # Episodic checkpointing: score the rollout's episodes.
        if self.config.checkpoint_metric == "mean":
            perf = jnp.mean(eval_metrics.episode_returns)
        elif self.config.checkpoint_metric == "min":
            perf = jnp.min(eval_metrics.episode_returns)
        elif self.config.checkpoint_metric == "max":
            perf = jnp.max(eval_metrics.episode_returns)
        else:
            raise ValueError(
                f"Unsupported checkpoint metric: {self.config.checkpoint_metric}. "
                "Must be one of 'min', 'max', or 'mean'."
            )

        # NOTE(review): under data parallelism `perf` is computed per device,
        # so checkpoint decisions may diverge across devices. Confirm whether
        # a cross-device reduction of `perf` is intended.
        def _update_checkpoint(ag_state):
            # Store the params that actually achieved `perf` (pre-training),
            # not the freshly trained ones, which were never evaluated.
            return ag_state.replace(
                params=ag_state.params.replace(
                    checkpoint_actor_params=rollout_params.actor_params,
                    checkpoint_encoder_params=rollout_params.fixed_encoder_params,
                ),
                extra_state=ag_state.extra_state.replace(best_perf=perf),
            )

        agent_state = jax.lax.cond(
            perf >= agent_state.extra_state.best_perf,
            _update_checkpoint,
            lambda ag_state: ag_state,
            agent_state,
        )

        # The scan averaged actor losses over *all* steps, with zeros on steps
        # that skipped the actor. Actor updates happen at 1-based global steps
        # t with t % policy_freq == 0; count them exactly (this is not always
        # num_updates_per_iter / policy_freq) and rescale so both actor_loss
        # and raw_loss_dict report the mean over real updates.
        step_ids = global_steps + 1 + iters
        num_actor_updates = jnp.sum(step_ids % self.config.policy_freq == 0)
        actor_loss_scale = self.config.num_updates_per_iter / jnp.maximum(
            num_actor_updates, 1
        )
        actor_loss = actor_loss * actor_loss_scale
        actor_loss_dict = jtu.tree_map(
            lambda x: x * actor_loss_scale, actor_loss_dict
        )

        train_metrics = TD7TrainMetric(
            encoder_loss=encoder_loss,
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict(
                {**enc_loss_dict, **critic_loss_dict, **actor_loss_dict}
            ),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        sampled_timesteps = jnp.uint32(eval_metrics.episode_lengths.sum())
        sampled_timesteps = psum(sampled_timesteps, axis_name=self.dp_axis_name)

        sampled_episodes = psum(
            jnp.uint32(self.config.rollout_episodes), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # state.env_state is passed through unchanged; the episode collector
        # performs its own resets.
        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run training until ``config.total_episodes`` episodes have been sampled.

        Each loop iteration calls ``self._multi_steps`` once. Train and
        workflow metrics are recorded every iteration. Evaluation runs every
        ``config.eval_interval`` iterations and on the final one, followed by
        a checkpoint save (forced on the final iteration).
        """
        num_devices = jax.device_count()
        one_step_episodes = self.config.rollout_episodes * num_devices
        sampled_episodes = state.metrics.sampled_episodes.tolist()
        num_iters = math.ceil(
            (self.config.total_episodes - sampled_episodes)
            / (one_step_episodes * self.config.fold_iters)
        )
        start_iteration = state.metrics.iterations.tolist()
        final_iteration = num_iters + start_iteration

        for i in range(start_iteration, final_iteration):
            iterations = i + 1
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            if (
                iterations % self.config.eval_interval == 0
                or iterations == final_iteration
            ):
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(
                iterations, saved_state, force=iterations == final_iteration
            )

        return state
923 return state