1import logging
2import math
3from typing import Any
4
5import chex
6import flax.linen as nn
7import jax
8import jax.numpy as jnp
9import jax.tree_util as jtu
10import optax
11from omegaconf import DictConfig
12
13from evorl.replay_buffers import ReplayBuffer
14from evorl.distributed import psum
15from evorl.distributed.gradients import gradient_update
16from evorl.envs import AutoresetMode, Box, create_env, Space
17from evorl.evaluators import Evaluator
18from evorl.metrics import MetricBase
19from evorl.rollout import rollout
20from evorl.networks import make_policy_network, make_q_network
21from evorl.sample_batch import SampleBatch
22from evorl.types import (
23 Action,
24 Params,
25 PyTreeData,
26 PyTreeDict,
27 PolicyExtraInfo,
28 State,
29 pytree_field,
30)
31from evorl.utils import running_statistics
32from evorl.utils.jax_utils import tree_stop_gradient, tree_get
33from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
34from evorl.agent import AgentState, Agent
35from evorl.recorders import add_prefix
36
37from ..offpolicy_utils import (
38 OffPolicyWorkflowTemplate,
39 clean_trajectory,
40 skip_replay_buffer_state,
41)
42
43
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Placeholder reported as ``actor_loss`` on iterations where the actor update
# is skipped; ``TD3V3Workflow.learn`` maps it to ``None`` before recording.
MISSING_LOSS = -1e10
47
48
[docs]
class TD3Agent(Agent):
    """The agent for TD3.

    Bundles one actor network and one critic network definition. The critic
    definition is initialized twice (two independent parameter sets), and a
    target copy is kept for the actor and for each critic.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    # Optional callable ``(obs, obs_preprocessor_state) -> obs``.
    # ``None`` means observations are used as-is.
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    discount: float = 0.99
    # Scale of the Gaussian noise added to actions in ``compute_actions``.
    exploration_epsilon: float = 0.5
    policy_noise: float = 0.2
    clip_policy_noise: float = 0.5
    critics_in_actor_loss: str = "first"  # or "min"

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs):
        """Apply the observation preprocessor when one is configured."""
        if not self.normalize_obs:
            return obs
        return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; a sample from it is used as a dummy
                network input.
            action_space: Action space; a sample from it is used as a dummy
                critic input.
            key: PRNG key used for sampling the dummy inputs and for
                initializing the networks.

        Returns:
            An ``AgentState`` holding the network parameters (targets start as
            copies of the online parameters) and, when observation
            normalization is enabled, the initial preprocessor state.
        """
        key, q_key1, q_key2, actor_key = jax.random.split(key, num=4)

        # Add a leading axis of size 1 to the sampled observation and action.
        sampled_obs = obs_space.sample(key)
        dummy_obs = jtu.tree_map(lambda leaf: leaf[None, ...], sampled_obs)
        dummy_action = action_space.sample(key)[None, ...]

        critic1_params = self.critic_network.init(q_key1, dummy_obs, dummy_action)
        critic2_params = self.critic_network.init(q_key2, dummy_obs, dummy_action)
        actor_params = self.actor_network.init(actor_key, dummy_obs)

        # Every target network starts out identical to its online network.
        network_params = TD3NetworkParams(
            actor_params=actor_params,
            target_actor_params=actor_params,
            critic1_params=critic1_params,
            target_critic1_params=critic1_params,
            critic2_params=critic2_params,
            target_critic2_params=critic2_params,
        )

        # Note: statistics are broadcasted to [T*B]
        preprocessor_state = (
            running_statistics.init_state(tree_get(dummy_obs, 0))
            if self.normalize_obs
            else None
        )

        return AgentState(
            params=network_params,
            obs_preprocessor_state=preprocessor_state,
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute exploratory actions for ``sample_batch.obs``.

        The actor output is perturbed with Gaussian noise scaled by
        ``exploration_epsilon`` and then clipped to ``[-1, 1]``.

        Returns:
            The noisy, clipped actions and an empty ``PyTreeDict``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        mean_actions = self.actor_network.apply(agent_state.params.actor_params, obs)

        exploration_noise = (
            jax.random.normal(key, mean_actions.shape) * self.exploration_epsilon
        )
        noisy_actions = jnp.clip(mean_actions + exploration_noise, -1.0, 1.0)

        return noisy_actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute actions for ``sample_batch.obs`` without exploration noise.

        ``key`` is accepted for interface compatibility and is not used.

        Returns:
            The actor output and an empty ``PyTreeDict``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        return (
            self.actor_network.apply(agent_state.params.actor_params, obs),
            PyTreeDict(),
        )
130
[docs]
def make_mlp_td3_agent(
    action_space: Space,
    norm_layer_type: str = "none",
    critic_hidden_layer_sizes: tuple[int] = (256, 256),
    actor_hidden_layer_sizes: tuple[int] = (256, 256),
    discount: float = 0.99,
    exploration_epsilon: float = 0.5,
    policy_noise: float = 0.2,
    clip_policy_noise: float = 0.5,
    critics_in_actor_loss: str = "first",  # or "min"
    normalize_obs: bool = False,
):
    """Build a ``TD3Agent`` with networks from the project's network factories.

    Args:
        action_space: Action space of the environment; must be a ``Box``. Its
            first dimension determines the actor's output size.
        norm_layer_type: Normalization layer type forwarded to both networks.
        critic_hidden_layer_sizes: Hidden layer sizes of the critic network.
        actor_hidden_layer_sizes: Hidden layer sizes of the actor network.
        discount: Forwarded to ``TD3Agent.discount``.
        exploration_epsilon: Forwarded to ``TD3Agent.exploration_epsilon``.
        policy_noise: Forwarded to ``TD3Agent.policy_noise``.
        clip_policy_noise: Forwarded to ``TD3Agent.clip_policy_noise``.
        critics_in_actor_loss: Forwarded to ``TD3Agent.critics_in_actor_loss``.
        normalize_obs: When true, ``running_statistics.normalize`` is installed
            as the agent's observation preprocessor.

    Returns:
        The configured ``TD3Agent``.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    assert isinstance(action_space, Box), "Only continue action space is supported."

    q_net = make_q_network(
        n_stack=1,
        hidden_layer_sizes=critic_hidden_layer_sizes,
        norm_layer_type=norm_layer_type,
    )
    # The final tanh keeps the actor output within [-1, 1].
    policy_net = make_policy_network(
        action_size=action_space.shape[0],
        hidden_layer_sizes=actor_hidden_layer_sizes,
        activation_final=nn.tanh,
        norm_layer_type=norm_layer_type,
    )

    return TD3Agent(
        critic_network=q_net,
        actor_network=policy_net,
        obs_preprocessor=running_statistics.normalize if normalize_obs else None,
        discount=discount,
        exploration_epsilon=exploration_epsilon,
        policy_noise=policy_noise,
        clip_policy_noise=clip_policy_noise,
        critics_in_actor_loss=critics_in_actor_loss,
    )
174
175
[docs]
class TD3TrainMetric(MetricBase):
    """Training metrics reported by one TD3 update step."""

    # Actor loss; set to ``MISSING_LOSS`` on steps where the actor is not updated.
    actor_loss: chex.Array
    # Losses of the two critics.
    critic1_loss: chex.Array
    critic2_loss: chex.Array
    # Mean Q-value predicted by each critic on the sampled batch.
    q1: chex.Array
    q2: chex.Array
182
183
[docs]
class TD3NetworkParams(PyTreeData):
    """Parameter container for the TD3 learner: online and target networks."""

    # Online (trained) parameters.
    actor_params: Params
    critic1_params: Params
    critic2_params: Params
    # Target parameters, updated via ``soft_target_update`` in the workflow.
    target_actor_params: Params
    target_critic1_params: Params
    target_critic2_params: Params
193
194
[docs]
class TD3V3Workflow(OffPolicyWorkflowTemplate):
    """TD3 workflow, similar to the implementations in SB3 and CleanRL.

    Each call to ``step`` collects one rollout, adds it to the replay buffer,
    performs one critic update on a sampled batch, and performs the actor and
    target-network update only every ``config.actor_update_interval``
    iterations.
    """

    @classmethod
    def name(cls):
        """Return the name of this workflow."""
        return "TD3-V3"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create all components from ``config`` and build the workflow.

        Args:
            config: Workflow configuration; the fields read here are the
                environment, network, optimizer, replay-buffer and TD3
                hyper-parameter settings referenced below.

        Returns:
            A workflow instance wired with the training env, agent, optimizer,
            evaluator and replay buffer.

        Raises:
            AssertionError: If the environment's action space is not a ``Box``.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            # ``ori_obs`` is read as the next observation in the critic update.
            record_ori_obs=True,
        )

        assert isinstance(env.action_space, Box), (
            "Only continue action space is supported."
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # one optimizer, separate opt_states (created in
        # ``_setup_agent_and_optimizer``) for the actor and each critic
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=max(
                config.batch_size, config.learning_start_timesteps
            ),
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(
            env,
            agent,
            optimizer,
            evaluator,
            replay_buffer,
            config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent state and one optimizer state per trained network.

        Args:
            key: PRNG key forwarded to ``agent.init``.

        Returns:
            The agent state and a ``PyTreeDict`` with the optimizer states
            under the keys ``actor``, ``critic1`` and ``critic2``.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic1=self.optimizer.init(agent_state.params.critic1_params),
            critic2=self.optimizer.init(agent_state.params.critic2_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: rollout, buffer insert, network updates.

        Args:
            state: Current workflow state.

        Returns:
            A tuple of the reduced ``TD3TrainMetric`` and the updated state.
            ``actor_loss`` equals ``MISSING_LOSS`` on iterations where the
            actor update is skipped.
        """
        iterations = state.metrics.iterations + 1
        key, rollout_key, critic_key, actor_key, rb_key = jax.random.split(
            state.key, num=5
        )

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        opt_state = state.opt_state

        # Fold the freshly collected observations into the running statistics.
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def _update_critic_fn(agent_state, opt_state, sample_batch, key):
            """Update both critics towards the clipped double-Q target."""
            critic1_opt_state = opt_state.critic1
            critic2_opt_state = opt_state.critic2

            agent = self.agent

            next_obs = sample_batch.extras.env_extras.ori_obs
            obs = sample_batch.obs
            actions = sample_batch.actions

            if agent.normalize_obs:
                next_obs = agent.obs_preprocessor(
                    next_obs, agent_state.obs_preprocessor_state
                )
                obs = agent.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

            # Target policy smoothing: clipped Gaussian noise on the target action.
            actions_next = agent.actor_network.apply(
                agent_state.params.target_actor_params, next_obs
            )
            actions_next += jnp.clip(
                jax.random.normal(key, actions.shape) * agent.policy_noise,
                -agent.clip_policy_noise,
                agent.clip_policy_noise,
            )
            # Note: when calculating the critic loss, we also clip the actions
            # to the action space
            actions_next = jnp.clip(actions_next, -1.0, 1.0)

            # [B]
            qs1_next = agent.critic_network.apply(
                agent_state.params.target_critic1_params, next_obs, actions_next
            )
            qs2_next = agent.critic_network.apply(
                agent_state.params.target_critic2_params, next_obs, actions_next
            )
            qs_next_min = jnp.minimum(qs1_next, qs2_next)

            # No bootstrapping where the ``termination`` flag is set.
            discounts = agent.discount * (
                1 - sample_batch.extras.env_extras.termination
            )

            qs_target = sample_batch.rewards + discounts * qs_next_min
            qs_target = jax.lax.stop_gradient(qs_target)

            def _critic_loss(params):
                qs = agent.critic_network.apply(params, obs, actions)
                loss = optax.squared_error(qs, qs_target).mean()
                return loss, qs.mean()

            critic_update_fn = gradient_update(
                _critic_loss,
                self.optimizer,
                dp_axis_name=self.dp_axis_name,
                has_aux=True,
            )

            # Both critics share the same loss function and regression target.
            (critic1_loss, q1), critic1_params, critic1_opt_state = critic_update_fn(
                critic1_opt_state, agent_state.params.critic1_params
            )
            (critic2_loss, q2), critic2_params, critic2_opt_state = critic_update_fn(
                critic2_opt_state, agent_state.params.critic2_params
            )

            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    critic1_params=critic1_params, critic2_params=critic2_params
                )
            )

            opt_state = opt_state.replace(
                critic1=critic1_opt_state, critic2=critic2_opt_state
            )

            train_info = PyTreeDict(
                critic1_loss=critic1_loss,
                critic2_loss=critic2_loss,
                q1=q1,
                q2=q2,
            )

            return (
                train_info,
                agent_state,
                opt_state,
            )

        def _update_actor_fn(agent_state, opt_state, sample_batch, key):
            """Update the actor, then soft-update all target networks."""
            actor_opt_state = opt_state.actor
            agent = self.agent
            obs = sample_batch.obs

            if agent.normalize_obs:
                # The preprocessor is a field of the agent, not of the workflow.
                obs = agent.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

            # ==== actor update ====
            # NOTE(review): only critic 1 is used here; the agent's
            # ``critics_in_actor_loss`` setting is not consulted -- confirm
            # whether "min" should be supported by this workflow.
            def _actor_loss(params):
                actions = agent.actor_network.apply(params, obs)
                loss = -agent.critic_network.apply(
                    agent_state.params.critic1_params, obs, actions
                ).mean()
                return loss

            actor_update_fn = gradient_update(
                _actor_loss,
                self.optimizer,
                dp_axis_name=self.dp_axis_name,
                has_aux=False,
            )

            actor_loss, actor_params, actor_opt_state = actor_update_fn(
                actor_opt_state, agent_state.params.actor_params
            )

            # soft update all target networks
            # The target actor must track the freshly updated ``actor_params``;
            # ``agent_state.params.actor_params`` is still the pre-update value.
            target_actor_params = soft_target_update(
                agent_state.params.target_actor_params,
                actor_params,
                self.config.tau,
            )
            target_critic1_params = soft_target_update(
                agent_state.params.target_critic1_params,
                agent_state.params.critic1_params,
                self.config.tau,
            )
            target_critic2_params = soft_target_update(
                agent_state.params.target_critic2_params,
                agent_state.params.critic2_params,
                self.config.tau,
            )

            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    actor_params=actor_params,
                    target_actor_params=target_actor_params,
                    target_critic1_params=target_critic1_params,
                    target_critic2_params=target_critic2_params,
                )
            )

            opt_state = opt_state.replace(actor=actor_opt_state)

            train_info = PyTreeDict(
                actor_loss=actor_loss,
            )

            return (
                train_info,
                agent_state,
                opt_state,
            )

        def _dummy_update_actor_fn(agent_state, opt_state, sample_batch, key):
            """Skip the actor update; report ``MISSING_LOSS`` as the actor loss."""
            return (
                PyTreeDict(actor_loss=MISSING_LOSS),
                agent_state,
                opt_state,
            )

        sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

        critic_train_info, agent_state, opt_state = _update_critic_fn(
            agent_state, opt_state, sample_batch, critic_key
        )

        # Note: using cond prohibits the parallel training with vmap
        (
            actor_train_info,
            agent_state,
            opt_state,
        ) = jax.lax.cond(
            iterations % self.config.actor_update_interval == 0,
            _update_actor_fn,
            _dummy_update_actor_fn,
            agent_state,
            opt_state,
            sample_batch,
            actor_key,
        )

        train_metrics = TD3TrainMetric(
            **actor_train_info,
            **critic_train_info,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timesteps
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            iterations=iterations,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Train until ``config.total_timesteps`` environment steps are sampled.

        Repeatedly calls ``self._multi_steps``, records training and workflow
        metrics, evaluates every ``config.eval_interval`` iterations (and at
        the final iteration), and saves a checkpoint after each call.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The workflow state after the last iteration.
        """
        num_devices = jax.device_count()
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        sampled_timesteps = state.metrics.sampled_timesteps.tolist()
        # Remaining timesteps divided by the timesteps consumed per loop turn.
        num_iters = math.ceil(
            (self.config.total_timesteps - sampled_timesteps)
            / (one_step_timesteps * self.config.fold_iters * num_devices)
        )
        start_iteration = state.metrics.iterations.tolist()
        final_iteration = num_iters + start_iteration

        for _ in range(num_iters):
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            # current iteration
            iterations = state.metrics.iterations.tolist()

            # Replace the skipped-actor-update placeholder with ``None``.
            train_metrics = jtu.tree_map(
                lambda x: None if x == MISSING_LOSS else x, train_metrics
            )

            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            if (
                iterations % self.config.eval_interval == 0
                or iterations == final_iteration
            ):
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(
                iterations,
                saved_state,
                force=iterations == final_iteration,
            )

        return state
557
558 saved_state = state
559 if not self.config.save_replay_buffer:
560 saved_state = skip_replay_buffer_state(saved_state)
561 self.checkpoint_manager.save(
562 iterations,
563 saved_state,
564 force=iterations == final_iteration,
565 )
566
567 return state