1import logging
2from typing import Any
3
4import chex
5import flax.linen as nn
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9import optax
10from omegaconf import DictConfig
11
12from evorl.distributed import psum, pmean
13from evorl.distributed.gradients import agent_gradient_update
14from evorl.envs import AutoresetMode, Box, create_env, Space
15from evorl.evaluators import Evaluator
16from evorl.metrics import MetricBase, metric_field
17from evorl.networks import make_policy_network, make_q_network
18from evorl.rollout import rollout
19from evorl.replay_buffers import ReplayBuffer
20from evorl.sample_batch import SampleBatch
21from evorl.types import (
22 Action,
23 LossDict,
24 Params,
25 PolicyExtraInfo,
26 PyTreeData,
27 PyTreeDict,
28 State,
29 pytree_field,
30)
31from evorl.utils import running_statistics
32from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, tree_get
33from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
34from evorl.agent import Agent, AgentState
35
36from .offpolicy_utils import OffPolicyWorkflowTemplate, clean_trajectory
37
38logger = logging.getLogger(__name__)
39
40
[docs]
class TD3TrainMetric(MetricBase):
    """Training metrics emitted by one TD3 workflow step."""

    # Critic loss reported for the step.
    critic_loss: chex.Array
    # Actor loss reported for the step.
    actor_loss: chex.Array
    # Merged loss dictionaries returned by the agent's loss functions.
    # Registered with `pmean` as the field's reduce function and defaults
    # to an empty PyTreeDict.
    raw_loss_dict: LossDict = metric_field(
        reduce_fn=pmean,
        default_factory=PyTreeDict,
    )
45
46
[docs]
class TD3NetworkParams(PyTreeData):
    """Parameter bundle for TD3: online and target actor/critic weights."""

    # Online networks.
    actor_params: Params
    critic_params: Params
    # Target networks.
    target_actor_params: Params
    target_critic_params: Params
52 target_actor_params: Params
53 target_critic_params: Params
54
55
[docs]
class TD3Agent(Agent):
    """The Agent for TD3.

    Bundles an actor network, a critic network and the scalar
    hyper-parameters used for action selection and for the critic/actor
    losses.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    # Called as ``obs_preprocessor(obs, obs_preprocessor_state)``.
    # ``None`` disables observation normalization (see `normalize_obs`).
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    # Discount factor applied to the bootstrapped target in `critic_loss`.
    discount: float = 0.99
    # Scale of the Gaussian noise added to actions in `compute_actions`.
    exploration_epsilon: float = 0.5
    # Scale of the Gaussian noise added to target actions in `critic_loss`.
    policy_noise: float = 0.2
    # Symmetric bound the target-action noise is clipped to.
    clip_policy_noise: float = 0.5
    # Which critic output drives the actor loss: "first" or "min".
    critics_in_actor_loss: str = "first"  # or "min"

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; a sample from it is used as a
                dummy input to initialize both networks.
            action_space: Action space; a sample from it is used as the
                dummy action input of the critic.
            key: PRNG key, split for the critic and actor initializers.

        Returns:
            An ``AgentState`` whose params are a ``TD3NetworkParams`` (the
            target params start out as the online params) and whose
            ``obs_preprocessor_state`` is a fresh running-statistics state
            when observation normalization is enabled, else ``None``.
        """
        key, q_key, actor_key = jax.random.split(key, num=3)

        # Prepend a batch axis of size 1 to the sampled dummy inputs.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))
        dummy_action = action_space.sample(key)[None, ...]

        critic_params = self.critic_network.init(q_key, dummy_obs, dummy_action)
        target_critic_params = critic_params

        actor_params = self.actor_network.init(actor_key, dummy_obs)
        target_actor_params = actor_params

        params_state = TD3NetworkParams(
            critic_params=critic_params,
            actor_params=actor_params,
            target_critic_params=target_critic_params,
            target_actor_params=target_actor_params,
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            # The statistics are initialized from a single (un-batched)
            # observation, hence the index 0.
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        return AgentState(
            params=params_state,
            obs_preprocessor_state=obs_preprocessor_state,
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute exploratory actions: actor output plus Gaussian noise.

        Args:
            agent_state: Current agent state (actor params, obs statistics).
            sample_batch: Batch providing ``obs``; [#env, ...].
            key: PRNG key used to draw the exploration noise.

        Returns:
            The noisy actions clipped to [-1, 1], and an empty
            ``PyTreeDict`` as the policy extra info.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        actions = self.actor_network.apply(agent_state.params.actor_params, obs)
        # add random noise
        noise = jax.random.normal(key, actions.shape) * self.exploration_epsilon
        actions += noise
        actions = jnp.clip(actions, -1.0, 1.0)

        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute deterministic actions (no exploration noise).

        Args:
            agent_state: Current agent state (actor params, obs statistics).
            sample_batch: Batch providing ``obs``; [#env, ...].
            key: Unused; kept so the signature matches `compute_actions`.

        Returns:
            The raw actor output and an empty ``PyTreeDict``.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        actions = self.actor_network.apply(agent_state.params.actor_params, obs)

        return actions, PyTreeDict()

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Critic loss in TD3.

        Args:
            agent_state: Current agent state.
            sample_batch: [B, ...]; must provide ``obs``, ``actions``,
                ``rewards`` and ``extras.env_extras.{ori_obs, termination}``.
            key: PRNG key used to draw the target-policy smoothing noise.

        Returns:
            LossDict with:
                critic_loss: squared TD error, summed over critics and
                    averaged over the batch.
                q_value: mean of the online critics' predictions.
        """
        # NOTE(review): `ori_obs` is used as the next observation here;
        # presumably it is the pre-autoreset observation recorded by the
        # env wrapper -- confirm against the env implementation.
        next_obs = sample_batch.extras.env_extras.ori_obs
        obs = sample_batch.obs
        actions = sample_batch.actions

        if self.normalize_obs:
            next_obs = self.obs_preprocessor(
                next_obs, agent_state.obs_preprocessor_state
            )
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        # Target policy smoothing: clipped Gaussian noise on the target action.
        next_actions = self.actor_network.apply(
            agent_state.params.target_actor_params, next_obs
        )
        next_actions += jnp.clip(
            jax.random.normal(key, actions.shape) * self.policy_noise,
            -self.clip_policy_noise,
            self.clip_policy_noise,
        )
        # Note: when calculating the critic loss, we also clip the actions to the action space
        next_actions = jnp.clip(next_actions, -1.0, 1.0)

        # [B, num_critics]
        next_qs = self.critic_network.apply(
            agent_state.params.target_critic_params, next_obs, next_actions
        )
        # Clipped double-Q: bootstrap from the smallest target critic value.
        next_qs_min = next_qs.min(-1)

        # No bootstrapping past a true termination.
        discounts = self.discount * (1 - sample_batch.extras.env_extras.termination)

        qs_target = sample_batch.rewards + discounts * next_qs_min
        # Repeat the shared target once per critic. The number of critics is
        # taken from the critic output rather than hard-coded to 2, so the
        # loss works for any `num_critics` the critic network was built with.
        qs_target = jnp.broadcast_to(
            qs_target[..., None], (*qs_target.shape, next_qs.shape[-1])
        )
        qs_target = jax.lax.stop_gradient(qs_target)

        qs = self.critic_network.apply(agent_state.params.critic_params, obs, actions)

        # Squared error is used (a Huber loss would be a drop-in alternative).
        q_loss = optax.squared_error(qs, qs_target).sum(-1).mean()

        return PyTreeDict(critic_loss=q_loss, q_value=qs.mean())

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Actor loss in TD3.

        Args:
            agent_state: Current agent state.
            sample_batch: [B, ...]; only ``obs`` is read.
            key: Unused; kept so the signature matches `critic_loss`.

        Returns:
            LossDict with:
                actor_loss: negated mean critic value of the actor's actions.

        Raises:
            ValueError: If ``critics_in_actor_loss`` is neither ``"first"``
                nor ``"min"``.
        """
        obs = sample_batch.obs

        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        # [T*B, A]
        # Note: when calculating the actor loss, we don't clip the actions to the action space, following the impl of SB3 and CleanRL
        actions = self.actor_network.apply(agent_state.params.actor_params, obs)

        # TODO: handle redundant computation
        qs = self.critic_network.apply(agent_state.params.critic_params, obs, actions)

        if self.critics_in_actor_loss == "first":
            actor_loss = -jnp.mean(qs[..., 0])
        elif self.critics_in_actor_loss == "min":
            # using min_Q, like SAC
            actor_loss = -jnp.mean(qs.min(-1))
        else:
            raise ValueError(
                f"Invalid value for critics_in_actor_loss: {self.critics_in_actor_loss}, should be 'first' or 'min'"
            )

        return PyTreeDict(actor_loss=actor_loss)
228
229
[docs]
def make_mlp_td3_agent(
    action_space: Space,
    norm_layer_type: str = "none",
    num_critics: int = 2,
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    discount: float = 0.99,
    exploration_epsilon: float = 0.5,
    policy_noise: float = 0.2,
    clip_policy_noise: float = 0.5,
    critics_in_actor_loss: str = "first",  # or "min"
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
):
    """Build a `TD3Agent` with MLP actor and critic networks.

    Args:
        action_space: Action space of the environment; must be a ``Box``.
            Its first dimension is used as the actor's output size.
        norm_layer_type: Forwarded to both network factories.
        num_critics: Number of stacked critics (``n_stack`` of the Q network).
        critic_hidden_layer_sizes: Hidden layer widths of the critic.
        actor_hidden_layer_sizes: Hidden layer widths of the actor.
        discount: Discount factor stored on the agent.
        exploration_epsilon: Exploration noise scale stored on the agent.
        policy_noise: Target-policy noise scale stored on the agent.
        clip_policy_noise: Clip bound for the target-policy noise.
        critics_in_actor_loss: ``"first"`` or ``"min"``; stored on the agent.
        normalize_obs: If True, ``running_statistics.normalize`` is used as
            the agent's observation preprocessor; otherwise none is set.
        policy_obs_key: Forwarded to the policy network as ``obs_key``.
        value_obs_key: Forwarded to the Q network as ``obs_key``.

    Returns:
        The configured ``TD3Agent``.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    # Kept as an assert (not a raise) so the exception type seen by existing
    # callers does not change.
    assert isinstance(action_space, Box), "Only continuous action space is supported."

    action_size = action_space.shape[0]

    critic_network = make_q_network(
        n_stack=num_critics,
        hidden_layer_sizes=critic_hidden_layer_sizes,
        norm_layer_type=norm_layer_type,
        obs_key=value_obs_key,
    )
    # The actor ends in tanh, so its raw output lies in [-1, 1].
    actor_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        activation_final=nn.tanh,
        norm_layer_type=norm_layer_type,
        obs_key=policy_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return TD3Agent(
        critic_network=critic_network,
        actor_network=actor_network,
        obs_preprocessor=obs_preprocessor,
        discount=discount,
        exploration_epsilon=exploration_epsilon,
        policy_noise=policy_noise,
        clip_policy_noise=clip_policy_noise,
        critics_in_actor_loss=critics_in_actor_loss,
    )
278
279
[docs]
class TD3Workflow(OffPolicyWorkflowTemplate):
    """Off-policy training workflow for TD3."""

    @classmethod
    def name(cls):
        """Return the workflow's name."""
        return "TD3"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Assemble the workflow's components from ``config``.

        Creates the training env (``AutoresetMode.NORMAL``, original
        observations recorded), the TD3 agent, the optimizer, the replay
        buffer, and an evaluator running on a separate env with autoreset
        disabled, then instantiates the workflow with them.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        net_cfg = config.agent_network
        agent_kwargs = dict(
            action_space=env.action_space,
            norm_layer_type=net_cfg.norm_layer_type,
            num_critics=net_cfg.num_critics,
            critic_hidden_layer_sizes=net_cfg.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=net_cfg.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
            policy_obs_key=net_cfg.policy_obs_key,
            value_obs_key=net_cfg.value_obs_key,
        )
        agent = make_mlp_td3_agent(**agent_kwargs)

        # one optimizer, two opt_states (in setup function) for both actor and critic
        clip_norm = config.optimizer.grad_clip_norm
        clipping_enabled = clip_norm is not None and clip_norm > 0
        if not clipping_enabled:
            optimizer = optax.adam(config.optimizer.lr)
        else:
            optimizer = optax.chain(
                optax.clip_by_global_norm(clip_norm),
                optax.adam(config.optimizer.lr),
            )

        # Sampling only starts once at least a full batch (and the
        # configured warm-up amount) has been collected.
        min_timesteps = max(config.batch_size, config.learning_start_timesteps)
        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=min_timesteps,
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, replay_buffer, config)

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent state and one optimizer state per network.

        The same optimizer is used for both networks; its state is kept
        separately under the ``actor`` and ``critic`` keys.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        params = agent_state.params
        opt_state = PyTreeDict(
            actor=self.optimizer.init(params.actor_params),
            critic=self.optimizer.init(params.critic_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration.

        Collects a rollout, updates the observation statistics (if enabled),
        stores the rollout in the replay buffer, and performs
        ``config.num_updates_per_iter`` rounds of updates. Each round does
        ``actor_update_interval - 1`` critic-only updates, then one critic
        update and one actor update on the same batch, followed by a soft
        update of both target networks.

        Returns:
            The training metrics of this iteration and the updated state.
        """
        next_key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Keep the done flags before the trajectory is cleaned and
        # flattened; they are only needed for the episode count below.
        dones = trajectory.dones
        trajectory = tree_stop_gradient(
            flatten_rollout_trajectory(clean_trajectory(trajectory))
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            updated_stats = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                dp_axis_name=self.dp_axis_name,
            )
            agent_state = agent_state.replace(obs_preprocessor_state=updated_stats)

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        # --- loss wrappers returning (scalar loss, aux dict) ---
        def critic_loss_fn(agent_state, sample_batch, key):
            losses = self.agent.critic_loss(agent_state, sample_batch, key)
            return losses.critic_loss, losses

        def actor_loss_fn(agent_state, sample_batch, key):
            losses = self.agent.actor_loss(agent_state, sample_batch, key)
            return losses.actor_loss, losses

        # --- accessors selecting which params each update differentiates ---
        def _get_critic_params(agent_state):
            return agent_state.params.critic_params

        def _set_critic_params(agent_state, critic_params):
            return agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            )

        def _get_actor_params(agent_state):
            return agent_state.params.actor_params

        def _set_actor_params(agent_state, actor_params):
            return agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            )

        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=_set_critic_params,
            detach_fn=_get_critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=_set_actor_params,
            detach_fn=_get_actor_params,
        )

        # Critic updates performed without an accompanying actor update.
        num_critic_only_updates = self.config.actor_update_interval - 1

        def _critic_only_update(carry, unused_t):
            key, agent_state, critic_opt_state = carry

            key, rb_key, critic_key = jax.random.split(key, num=3)
            # it's safe to use read-only replay_buffer_state here.
            batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            # The losses of these intermediate updates are not reported.
            _, agent_state, critic_opt_state = critic_update_fn(
                critic_opt_state, agent_state, batch, critic_key
            )

            return (key, agent_state, critic_opt_state), None

        def _update_round(carry, unused_t):
            key, agent_state, opt_state = carry

            actor_opt_state = opt_state.actor
            critic_opt_state = opt_state.critic

            key, critic_key, actor_key, rb_key = jax.random.split(key, num=4)

            if num_critic_only_updates > 0:
                key, critic_only_key = jax.random.split(key)

                (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                    _critic_only_update,
                    (critic_only_key, agent_state, critic_opt_state),
                    (),
                    length=num_critic_only_updates,
                )

            # The final critic update and the actor update share one batch.
            batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(critic_opt_state, agent_state, batch, critic_key)
            )

            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(actor_opt_state, agent_state, batch, actor_key)
            )

            # Soft-update both target networks towards the online ones.
            params = agent_state.params
            new_target_actor = soft_target_update(
                params.target_actor_params,
                params.actor_params,
                self.config.tau,
            )
            new_target_critic = soft_target_update(
                params.target_critic_params,
                params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=params.replace(
                    target_actor_params=new_target_actor,
                    target_critic_params=new_target_critic,
                )
            )

            opt_state = opt_state.replace(
                actor=actor_opt_state, critic=critic_opt_state
            )

            new_carry = (key, agent_state, opt_state)
            outputs = (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict)
            return new_carry, outputs

        final_carry, losses = scan_and_mean(
            _update_round,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )
        _, agent_state, opt_state = final_carry
        critic_loss, actor_loss, critic_loss_dict, actor_loss_dict = losses

        train_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        timesteps_this_iter = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        episodes_this_iter = psum(
            dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        prev_metrics = state.metrics
        workflow_metrics = prev_metrics.replace(
            sampled_timesteps=prev_metrics.sampled_timesteps + timesteps_this_iter,
            sampled_episodes=prev_metrics.sampled_episodes + episodes_this_iter,
            iterations=prev_metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        new_state = state.replace(
            key=next_key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )
        return train_metrics, new_state