Source code for evorl.algorithms.contrib.td3_v3

  1import logging
  2import math
  3from typing import Any
  4
  5import chex
  6import flax.linen as nn
  7import jax
  8import jax.numpy as jnp
  9import jax.tree_util as jtu
 10import optax
 11from omegaconf import DictConfig
 12
 13from evorl.replay_buffers import ReplayBuffer
 14from evorl.distributed import psum
 15from evorl.distributed.gradients import gradient_update
 16from evorl.envs import AutoresetMode, Box, create_env, Space
 17from evorl.evaluators import Evaluator
 18from evorl.metrics import MetricBase
 19from evorl.rollout import rollout
 20from evorl.networks import make_policy_network, make_q_network
 21from evorl.sample_batch import SampleBatch
 22from evorl.types import (
 23    Action,
 24    Params,
 25    PyTreeData,
 26    PyTreeDict,
 27    PolicyExtraInfo,
 28    State,
 29    pytree_field,
 30)
 31from evorl.utils import running_statistics
 32from evorl.utils.jax_utils import tree_stop_gradient, tree_get
 33from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
 34from evorl.agent import AgentState, Agent
 35from evorl.recorders import add_prefix
 36
 37from ..offpolicy_utils import (
 38    OffPolicyWorkflowTemplate,
 39    clean_trajectory,
 40    skip_replay_buffer_state,
 41)
 42
 43
 44logger = logging.getLogger(__name__)
 45
 46MISSING_LOSS = -1e10
 47
 48
class TD3Agent(Agent):
    """Agent for TD3: one actor network and one critic module used with two parameter sets.

    The agent owns the network definitions and hyper-parameters; the network
    parameters themselves live in the ``AgentState`` returned by :meth:`init`.
    """

    # Q-network module; applied with ``critic1_params`` and ``critic2_params``.
    critic_network: nn.Module
    # Policy module; ``apply(params, obs)`` yields the actions.
    actor_network: nn.Module
    # Optional callable ``(obs, obs_preprocessor_state) -> obs``.
    # ``None`` disables observation normalization.
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    # The hyper-parameters below, except ``exploration_epsilon``, are not read
    # inside this class; they are consumed by the training workflow.
    discount: float = 0.99
    # Scale of the Gaussian noise added to actions in ``compute_actions``.
    exploration_epsilon: float = 0.5
    policy_noise: float = 0.2
    clip_policy_noise: float = 0.5
    critics_in_actor_loss: str = "first"  # or "min"

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; one sample is drawn to shape the networks.
            action_space: Action space; one sample is drawn to shape the critics.
            key: PRNG key, split into the sampling key and the three init keys.

        Returns:
            An ``AgentState`` holding a ``TD3NetworkParams`` (targets start from
            the same parameters as the online networks) and, when observation
            normalization is enabled, the running-statistics state.
        """
        sample_key, critic1_key, critic2_key, policy_key = jax.random.split(
            key, num=4
        )

        # Add a leading batch axis of size 1 to the sampled observation/action.
        # The same key is used for both samples.
        batched_obs = jtu.tree_map(
            lambda leaf: leaf[None, ...], obs_space.sample(sample_key)
        )
        batched_action = action_space.sample(sample_key)[None, ...]

        critic1_params = self.critic_network.init(
            critic1_key, batched_obs, batched_action
        )
        critic2_params = self.critic_network.init(
            critic2_key, batched_obs, batched_action
        )
        actor_params = self.actor_network.init(policy_key, batched_obs)

        network_params = TD3NetworkParams(
            actor_params=actor_params,
            target_actor_params=actor_params,
            critic1_params=critic1_params,
            target_critic1_params=critic1_params,
            critic2_params=critic2_params,
            target_critic2_params=critic2_params,
        )

        # Note: statistics are broadcasted to [T*B]
        preprocessor_state = (
            running_statistics.init_state(tree_get(batched_obs, 0))
            if self.normalize_obs
            else None
        )

        return AgentState(
            params=network_params,
            obs_preprocessor_state=preprocessor_state,
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute exploratory actions for ``sample_batch.obs``.

        Gaussian noise scaled by ``exploration_epsilon`` is added to the actor
        output and the sum is clipped to ``[-1, 1]``.

        Returns:
            The noisy, clipped actions and an empty ``PyTreeDict``.
        """
        policy_input = sample_batch.obs
        if self.normalize_obs:
            policy_input = self.obs_preprocessor(
                policy_input, agent_state.obs_preprocessor_state
            )

        raw_actions = self.actor_network.apply(
            agent_state.params.actor_params, policy_input
        )
        exploration_noise = (
            jax.random.normal(key, raw_actions.shape) * self.exploration_epsilon
        )
        noisy_actions = jnp.clip(raw_actions + exploration_noise, -1.0, 1.0)

        return noisy_actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute noise-free actions for ``sample_batch.obs``.

        ``key`` is accepted for interface compatibility and is not used.

        Returns:
            The actor output and an empty ``PyTreeDict``.
        """
        policy_input = sample_batch.obs
        if self.normalize_obs:
            policy_input = self.obs_preprocessor(
                policy_input, agent_state.obs_preprocessor_state
            )

        return (
            self.actor_network.apply(agent_state.params.actor_params, policy_input),
            PyTreeDict(),
        )
129 130
def make_mlp_td3_agent(
    action_space: Space,
    norm_layer_type: str = "none",
    critic_hidden_layer_sizes: tuple[int] = (256, 256),
    actor_hidden_layer_sizes: tuple[int] = (256, 256),
    discount: float = 0.99,
    exploration_epsilon: float = 0.5,
    policy_noise: float = 0.2,
    clip_policy_noise: float = 0.5,
    critics_in_actor_loss: str = "first",  # or "min"
    normalize_obs: bool = False,
):
    """Build a ``TD3Agent`` from network and hyper-parameter settings.

    Args:
        action_space: Must be a ``Box``; its first dimension is the action size.
        norm_layer_type: Forwarded to both network factories.
        critic_hidden_layer_sizes: Hidden layer sizes of the Q-network.
        actor_hidden_layer_sizes: Hidden layer sizes of the policy network.
        discount: Forwarded to ``TD3Agent``.
        exploration_epsilon: Forwarded to ``TD3Agent``.
        policy_noise: Forwarded to ``TD3Agent``.
        clip_policy_noise: Forwarded to ``TD3Agent``.
        critics_in_actor_loss: Forwarded to ``TD3Agent``.
        normalize_obs: When true, ``running_statistics.normalize`` is installed
            as the observation preprocessor; otherwise none is used.

    Returns:
        The configured ``TD3Agent``.
    """
    assert isinstance(action_space, Box), "Only continue action space is supported."

    preprocessor = running_statistics.normalize if normalize_obs else None

    return TD3Agent(
        critic_network=make_q_network(
            n_stack=1,
            hidden_layer_sizes=critic_hidden_layer_sizes,
            norm_layer_type=norm_layer_type,
        ),
        # The final tanh activation is applied by the policy network itself.
        actor_network=make_policy_network(
            action_size=action_space.shape[0],
            hidden_layer_sizes=actor_hidden_layer_sizes,
            activation_final=nn.tanh,
            norm_layer_type=norm_layer_type,
        ),
        obs_preprocessor=preprocessor,
        discount=discount,
        exploration_epsilon=exploration_epsilon,
        policy_noise=policy_noise,
        clip_policy_noise=clip_policy_noise,
        critics_in_actor_loss=critics_in_actor_loss,
    )
174 175
class TD3TrainMetric(MetricBase):
    """Per-iteration training metrics produced by ``TD3V3Workflow.step``."""

    # Negated mean critic value of the actor's actions, or ``MISSING_LOSS`` on
    # iterations where the actor update is skipped.
    actor_loss: chex.Array
    # Mean squared error between each critic's prediction and the TD target.
    critic1_loss: chex.Array
    critic2_loss: chex.Array
    # Mean prediction of each critic on the sampled batch.
    q1: chex.Array
    q2: chex.Array
182 183
class TD3NetworkParams(PyTreeData):
    """Parameter container for the learner: online and target networks."""

    # Online networks, updated by gradient steps.
    actor_params: Params
    critic1_params: Params
    critic2_params: Params
    # Target networks, updated by soft (Polyak) updates towards the online ones.
    target_actor_params: Params
    target_critic1_params: Params
    target_critic2_params: Params
193 194
class TD3V3Workflow(OffPolicyWorkflowTemplate):
    """The similar impl of TD3 in SB3 and CleanRL.

    Each :meth:`step` collects one rollout, appends it to the replay buffer,
    runs one critic update on a sampled batch and, every
    ``config.actor_update_interval`` iterations, one actor update followed by a
    soft update of all target networks.
    """

    @classmethod
    def name(cls):
        """Return the workflow's name."""
        return "TD3-V3"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct the workflow (envs, agent, optimizer, buffer, evaluator).

        Args:
            config: Run configuration; the fields read are visible below.

        Returns:
            A workflow instance built from the created components.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        assert isinstance(env.action_space, Box), (
            "Only continue action space is supported."
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # one optimizer, three opt_states (in setup function): actor and both critics
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            # Never sample before at least one full batch is available.
            min_sample_timesteps=max(
                config.batch_size, config.learning_start_timesteps
            ),
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(
            env,
            agent,
            optimizer,
            evaluator,
            replay_buffer,
            config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent state and one optimizer state per online network.

        Returns:
            ``(agent_state, opt_state)`` where ``opt_state`` is a ``PyTreeDict``
            with the keys ``actor``, ``critic1`` and ``critic2``.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic1=self.optimizer.init(agent_state.params.critic1_params),
            critic2=self.optimizer.init(agent_state.params.critic2_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration.

        Collects a rollout, updates the observation statistics (if enabled),
        stores the rollout in the replay buffer, samples one batch, updates
        both critics and, when ``iterations % config.actor_update_interval == 0``,
        updates the actor and soft-updates all target networks.

        Args:
            state: Current workflow state.

        Returns:
            ``(train_metrics, new_state)``. On iterations without an actor
            update ``train_metrics.actor_loss`` equals ``MISSING_LOSS``.
        """
        iterations = state.metrics.iterations + 1
        key, rollout_key, critic_key, actor_key, rb_key = jax.random.split(
            state.key, num=5
        )

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        opt_state = state.opt_state

        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def _update_critic_fn(agent_state, opt_state, sample_batch, key):
            """Take one gradient step on both critics against a shared TD target."""
            critic1_opt_state = opt_state.critic1
            critic2_opt_state = opt_state.critic2

            agent = self.agent

            # ``ori_obs`` is used as the next observation.
            # NOTE(review): presumably the pre-autoreset observation recorded
            # because of ``record_ori_obs=True`` -- confirm against the env wrapper.
            next_obs = sample_batch.extras.env_extras.ori_obs
            obs = sample_batch.obs
            actions = sample_batch.actions

            if agent.normalize_obs:
                next_obs = agent.obs_preprocessor(
                    next_obs, agent_state.obs_preprocessor_state
                )
                obs = agent.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

            # Target policy smoothing: clipped Gaussian noise on the target action.
            actions_next = agent.actor_network.apply(
                agent_state.params.target_actor_params, next_obs
            )
            actions_next += jnp.clip(
                jax.random.normal(key, actions.shape) * agent.policy_noise,
                -agent.clip_policy_noise,
                agent.clip_policy_noise,
            )
            # Note: when calculating the critic loss, we also clip the actions to the action space
            actions_next = jnp.clip(actions_next, -1.0, 1.0)

            # [B]
            qs1_next = agent.critic_network.apply(
                agent_state.params.target_critic1_params, next_obs, actions_next
            )
            qs2_next = agent.critic_network.apply(
                agent_state.params.target_critic2_params, next_obs, actions_next
            )
            # Clipped double-Q: bootstrap from the smaller target estimate.
            qs_next_min = jnp.minimum(qs1_next, qs2_next)

            # No bootstrapping past terminated transitions.
            discounts = agent.discount * (
                1 - sample_batch.extras.env_extras.termination
            )

            qs_target = sample_batch.rewards + discounts * qs_next_min
            qs_target = jax.lax.stop_gradient(qs_target)

            def _critic_loss(params):
                qs = agent.critic_network.apply(params, obs, actions)
                loss = optax.squared_error(qs, qs_target).mean()
                return loss, qs.mean()

            # The same loss/update function serves both critics; only the
            # parameters and optimizer state differ.
            critic_update_fn = gradient_update(
                _critic_loss,
                self.optimizer,
                dp_axis_name=self.dp_axis_name,
                has_aux=True,
            )

            (critic1_loss, q1), critic1_params, critic1_opt_state = critic_update_fn(
                critic1_opt_state, agent_state.params.critic1_params
            )
            (critic2_loss, q2), critic2_params, critic2_opt_state = critic_update_fn(
                critic2_opt_state, agent_state.params.critic2_params
            )

            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    critic1_params=critic1_params, critic2_params=critic2_params
                )
            )

            opt_state = opt_state.replace(
                critic1=critic1_opt_state, critic2=critic2_opt_state
            )

            train_info = PyTreeDict(
                critic1_loss=critic1_loss,
                critic2_loss=critic2_loss,
                q1=q1,
                q2=q2,
            )

            return (
                train_info,
                agent_state,
                opt_state,
            )

        def _update_actor_fn(agent_state, opt_state, sample_batch, key):
            """Take one actor gradient step, then soft-update all target networks."""
            actor_opt_state = opt_state.actor
            agent = self.agent
            obs = sample_batch.obs

            if agent.normalize_obs:
                # The preprocessor belongs to the agent, not to the workflow.
                obs = agent.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

            # ==== actor update ====
            def _actor_loss(params):
                actions = agent.actor_network.apply(params, obs)
                qs = agent.critic_network.apply(
                    agent_state.params.critic1_params, obs, actions
                )
                # Honour ``critics_in_actor_loss``: "min" uses the smaller of
                # the two critics; any other value keeps the first critic only.
                if agent.critics_in_actor_loss == "min":
                    qs2 = agent.critic_network.apply(
                        agent_state.params.critic2_params, obs, actions
                    )
                    qs = jnp.minimum(qs, qs2)
                return -qs.mean()

            actor_update_fn = gradient_update(
                _actor_loss,
                self.optimizer,
                dp_axis_name=self.dp_axis_name,
                has_aux=False,
            )

            actor_loss, actor_params, actor_opt_state = actor_update_fn(
                actor_opt_state, agent_state.params.actor_params
            )

            # soft update all target networks
            # The target actor must track the freshly updated ``actor_params``;
            # ``agent_state.params.actor_params`` still holds the pre-update
            # values here. The critic params in ``agent_state`` were already
            # updated by ``_update_critic_fn``.
            target_actor_params = soft_target_update(
                agent_state.params.target_actor_params,
                actor_params,
                self.config.tau,
            )
            target_critic1_params = soft_target_update(
                agent_state.params.target_critic1_params,
                agent_state.params.critic1_params,
                self.config.tau,
            )
            target_critic2_params = soft_target_update(
                agent_state.params.target_critic2_params,
                agent_state.params.critic2_params,
                self.config.tau,
            )

            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    actor_params=actor_params,
                    target_actor_params=target_actor_params,
                    target_critic1_params=target_critic1_params,
                    target_critic2_params=target_critic2_params,
                )
            )

            opt_state = opt_state.replace(actor=actor_opt_state)

            train_info = PyTreeDict(
                actor_loss=actor_loss,
            )

            return (
                train_info,
                agent_state,
                opt_state,
            )

        def _dummy_update_actor_fn(agent_state, opt_state, sample_batch, key):
            """Skip the actor update; report ``MISSING_LOSS`` and pass state through."""
            return (
                PyTreeDict(actor_loss=MISSING_LOSS),
                agent_state,
                opt_state,
            )

        sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

        critic_train_info, agent_state, opt_state = _update_critic_fn(
            agent_state, opt_state, sample_batch, critic_key
        )

        # Note: using cond prohibits the parallel training with vmap
        (
            actor_train_info,
            agent_state,
            opt_state,
        ) = jax.lax.cond(
            iterations % self.config.actor_update_interval == 0,
            _update_actor_fn,
            _dummy_update_actor_fn,
            agent_state,
            opt_state,
            sample_batch,
            actor_key,
        )

        train_metrics = TD3TrainMetric(
            **actor_train_info,
            **critic_train_info,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            iterations=iterations,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run the training loop until ``config.total_timesteps`` is reached.

        Each loop iteration calls ``self._multi_steps``, records the train and
        workflow metrics, evaluates every ``config.eval_interval`` iterations
        (and on the final one), and saves a checkpoint.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The workflow state after the last iteration.
        """
        num_devices = jax.device_count()
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        sampled_timesteps = state.metrics.sampled_timesteps.tolist()
        # Remaining timesteps divided by the timesteps one loop iteration consumes.
        num_iters = math.ceil(
            (self.config.total_timesteps - sampled_timesteps)
            / (one_step_timesteps * self.config.fold_iters * num_devices)
        )
        start_iteration = state.metrics.iterations.tolist()
        final_iteration = num_iters + start_iteration

        for _ in range(num_iters):
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            # current iteration
            iterations = state.metrics.iterations.tolist()

            # Replace the skipped-actor-update sentinel with ``None``.
            train_metrics = jtu.tree_map(
                lambda x: None if x == MISSING_LOSS else x, train_metrics
            )

            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            if (
                iterations % self.config.eval_interval == 0
                or iterations == final_iteration
            ):
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(
                iterations,
                saved_state,
                force=iterations == final_iteration,
            )

        return state