Source code for evorl.algorithms.td3

  1import logging
  2from typing import Any
  3
  4import chex
  5import flax.linen as nn
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9import optax
 10from omegaconf import DictConfig
 11
 12from evorl.distributed import psum, pmean
 13from evorl.distributed.gradients import agent_gradient_update
 14from evorl.envs import AutoresetMode, Box, create_env, Space
 15from evorl.evaluators import Evaluator
 16from evorl.metrics import MetricBase, metric_field
 17from evorl.networks import make_policy_network, make_q_network
 18from evorl.rollout import rollout
 19from evorl.replay_buffers import ReplayBuffer
 20from evorl.sample_batch import SampleBatch
 21from evorl.types import (
 22    Action,
 23    LossDict,
 24    Params,
 25    PolicyExtraInfo,
 26    PyTreeData,
 27    PyTreeDict,
 28    State,
 29    pytree_field,
 30)
 31from evorl.utils import running_statistics
 32from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, tree_get
 33from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
 34from evorl.agent import Agent, AgentState
 35
 36from .offpolicy_utils import OffPolicyWorkflowTemplate, clean_trajectory
 37
 38logger = logging.getLogger(__name__)
 39
 40
class TD3TrainMetric(MetricBase):
    """Training metrics produced by one TD3 training step.

    ``critic_loss`` and ``actor_loss`` hold the loss values reported by the
    workflow's update loop. ``raw_loss_dict`` carries the merged loss
    dictionaries returned by the agent's loss functions.
    """

    # Loss of the critic (Q-network) update.
    critic_loss: chex.Array
    # Loss of the actor (policy) update.
    actor_loss: chex.Array
    # Extra loss terms; starts as an empty PyTreeDict and is declared with
    # ``pmean`` as its reduce function.
    raw_loss_dict: LossDict = metric_field(
        reduce_fn=pmean,
        default_factory=PyTreeDict,
    )
45 46
class TD3NetworkParams(PyTreeData):
    """Contains training state for the learner.

    Bundles the online and target parameters of the TD3 actor and critic
    networks. ``TD3Agent.init`` creates the target entries as copies of the
    online ones; ``TD3Workflow.step`` later moves them with
    ``soft_target_update``.
    """

    # Online actor (policy) network parameters.
    actor_params: Params
    # Online critic (Q-network) parameters.
    critic_params: Params
    # Target actor parameters, used to pick next actions in the critic loss.
    target_actor_params: Params
    # Target critic parameters, used to evaluate the bootstrapped Q target.
    target_critic_params: Params
54 55
class TD3Agent(Agent):
    """The Agent for TD3.

    Holds a deterministic actor network and a stacked critic network whose
    output has one Q-value per critic on its last axis.

    Attributes:
        critic_network: Q-network module, called as ``apply(params, obs, actions)``.
        actor_network: Policy module, called as ``apply(params, obs)``.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``
            applied to observations; ``None`` disables normalization.
        discount: Discount factor used in the TD target.
        exploration_epsilon: Std of the Gaussian noise added to actions in
            ``compute_actions``.
        policy_noise: Std of the Gaussian noise added to target actions in
            ``critic_loss``.
        clip_policy_noise: Absolute clipping bound for that target noise.
        critics_in_actor_loss: ``"first"`` to use only the first critic in the
            actor loss, or ``"min"`` to use the minimum over all critics.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    discount: float = 0.99
    exploration_epsilon: float = 0.5
    policy_noise: float = 0.2
    clip_policy_noise: float = 0.5
    critics_in_actor_loss: str = "first"  # or "min"

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space, sampled to build a dummy observation.
            action_space: Action space, sampled to build a dummy action.
            key: PRNG key used for parameter initialization.

        Returns:
            An ``AgentState`` holding ``TD3NetworkParams`` (target parameters
            start equal to the online ones) and, when observation
            normalization is enabled, a fresh running-statistics state.
        """
        key, q_key, actor_key = jax.random.split(key, num=3)

        # Dummy inputs with a leading batch axis of size 1; they are only fed
        # to ``network.init``. The same key is used for both samples.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))
        dummy_action = action_space.sample(key)[None, ...]

        critic_params = self.critic_network.init(q_key, dummy_obs, dummy_action)
        target_critic_params = critic_params

        actor_params = self.actor_network.init(actor_key, dummy_obs)
        target_actor_params = actor_params

        params_state = TD3NetworkParams(
            critic_params=critic_params,
            actor_params=actor_params,
            target_critic_params=target_critic_params,
            target_actor_params=target_actor_params,
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            # The statistics state is built from one unbatched observation.
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        return AgentState(
            params=params_state,
            obs_preprocessor_state=obs_preprocessor_state,
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute exploratory actions for rollouts.

        Args:
            agent_state: Current agent state.
            sample_batch: Batch with ``obs`` of shape [#env, ...].
            key: PRNG key for the exploration noise.

        Returns:
            A tuple ``(actions, extra_info)``: the actor output plus Gaussian
            noise scaled by ``exploration_epsilon``, clipped to [-1, 1], and
            an empty ``PyTreeDict``.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        actions = self.actor_network.apply(agent_state.params.actor_params, obs)
        # add random noise
        noise = jax.random.normal(key, actions.shape) * self.exploration_epsilon
        actions += noise
        actions = jnp.clip(actions, -1.0, 1.0)

        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute deterministic actions for evaluation.

        Args:
            agent_state: Current agent state.
            sample_batch: Batch with ``obs`` of shape [#env, ...].
            key: Unused; kept so the signature matches ``compute_actions``.

        Returns:
            A tuple ``(actions, extra_info)``: the raw actor output (no noise,
            no clipping) and an empty ``PyTreeDict``.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        actions = self.actor_network.apply(agent_state.params.actor_params, obs)

        return actions, PyTreeDict()

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Critic loss in TD3.

        Builds the target from the target actor and target critics, using
        clipped Gaussian noise on the target action and the minimum over the
        critics, then regresses every critic onto that shared target.

        Args:
            agent_state: Current agent state.
            sample_batch: Batch of shape [B, ...]. Reads ``obs``, ``actions``,
                ``rewards``, ``extras.env_extras.ori_obs`` (used as the next
                observation) and ``extras.env_extras.termination``.
            key: PRNG key for the target policy noise.

        Returns:
            LossDict with:
                critic_loss: squared error summed over critics, averaged over
                    the batch.
                q_value: mean of the online critics' Q-values.
        """
        next_obs = sample_batch.extras.env_extras.ori_obs
        obs = sample_batch.obs
        actions = sample_batch.actions

        if self.normalize_obs:
            next_obs = self.obs_preprocessor(
                next_obs, agent_state.obs_preprocessor_state
            )
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        next_actions = self.actor_network.apply(
            agent_state.params.target_actor_params, next_obs
        )
        next_actions += jnp.clip(
            jax.random.normal(key, actions.shape) * self.policy_noise,
            -self.clip_policy_noise,
            self.clip_policy_noise,
        )
        # Note: when calculating the critic loss, we also clip the actions to the action space
        next_actions = jnp.clip(next_actions, -1.0, 1.0)

        # [B, num_critics]
        next_qs = self.critic_network.apply(
            agent_state.params.target_critic_params, next_obs, next_actions
        )
        next_qs_min = next_qs.min(-1)

        # Only true terminations zero the bootstrap term.
        discounts = self.discount * (1 - sample_batch.extras.env_extras.termination)

        qs_target = sample_batch.rewards + discounts * next_qs_min
        # Every critic regresses onto the same target. The number of critics is
        # read from the critic output instead of being hard-coded to 2, so the
        # loss stays correct for any ``num_critics``.
        num_critics = next_qs.shape[-1]
        qs_target = jnp.broadcast_to(
            qs_target[..., None], (*qs_target.shape, num_critics)
        )
        qs_target = jax.lax.stop_gradient(qs_target)

        qs = self.critic_network.apply(agent_state.params.critic_params, obs, actions)

        # Plain squared error; optax.huber_loss(qs, qs_target) is a drop-in
        # alternative.
        q_loss = optax.squared_error(qs, qs_target).sum(-1).mean()

        return PyTreeDict(critic_loss=q_loss, q_value=qs.mean())

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Actor loss in TD3.

        Args:
            agent_state: Current agent state.
            sample_batch: Batch of shape [B, ...]; only ``obs`` is read.
            key: Unused; kept so the signature matches ``critic_loss``.

        Returns:
            LossDict with:
                actor_loss: negative mean Q-value of the actor's actions,
                    using the critic(s) chosen by ``critics_in_actor_loss``.

        Raises:
            ValueError: If ``critics_in_actor_loss`` is neither ``"first"``
                nor ``"min"``.
        """
        obs = sample_batch.obs

        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        # [T*B, A]
        # Note: when calculating the actor loss, we don't clip the actions to the action space, following the impl of SB3 and CleanRL
        actions = self.actor_network.apply(agent_state.params.actor_params, obs)

        # TODO: handle redundant computation
        qs = self.critic_network.apply(agent_state.params.critic_params, obs, actions)

        if self.critics_in_actor_loss == "first":
            actor_loss = -jnp.mean(qs[..., 0])
        elif self.critics_in_actor_loss == "min":
            # using min_Q, like SAC
            actor_loss = -jnp.mean(qs.min(-1))
        else:
            raise ValueError(
                f"Invalid value for critics_in_actor_loss: {self.critics_in_actor_loss}, should be 'first' or 'min'"
            )

        return PyTreeDict(actor_loss=actor_loss)
228 229
def make_mlp_td3_agent(
    action_space: Space,
    norm_layer_type: str = "none",
    num_critics: int = 2,
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    discount: float = 0.99,
    exploration_epsilon: float = 0.5,
    policy_noise: float = 0.2,
    clip_policy_noise: float = 0.5,
    critics_in_actor_loss: str = "first",  # or "min"
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> TD3Agent:
    """Build a ``TD3Agent`` from network and algorithm settings.

    Args:
        action_space: Action space of the environment; must be a ``Box``.
            Its first dimension gives the actor's output size.
        norm_layer_type: Normalization layer type forwarded to both networks.
        num_critics: Number of stacked critics (``n_stack`` of the Q-network).
        critic_hidden_layer_sizes: Hidden layer widths of the critic.
        actor_hidden_layer_sizes: Hidden layer widths of the actor.
        discount: Discount factor stored on the agent.
        exploration_epsilon: Exploration noise scale stored on the agent.
        policy_noise: Target policy noise scale stored on the agent.
        clip_policy_noise: Clipping bound for the target policy noise.
        critics_in_actor_loss: ``"first"`` or ``"min"``; see ``TD3Agent``.
        normalize_obs: If True, use ``running_statistics.normalize`` as the
            observation preprocessor; otherwise none is used.
        policy_obs_key: ``obs_key`` forwarded to the policy network.
        value_obs_key: ``obs_key`` forwarded to the Q-network.

    Returns:
        The configured ``TD3Agent``.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    assert isinstance(action_space, Box), "Only continuous action space is supported."

    action_size = action_space.shape[0]

    critic_network = make_q_network(
        n_stack=num_critics,
        hidden_layer_sizes=critic_hidden_layer_sizes,
        norm_layer_type=norm_layer_type,
        obs_key=value_obs_key,
    )
    # tanh on the final layer keeps the actor output in [-1, 1].
    actor_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        activation_final=nn.tanh,
        norm_layer_type=norm_layer_type,
        obs_key=policy_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return TD3Agent(
        critic_network=critic_network,
        actor_network=actor_network,
        obs_preprocessor=obs_preprocessor,
        discount=discount,
        exploration_epsilon=exploration_epsilon,
        policy_noise=policy_noise,
        clip_policy_noise=clip_policy_noise,
        critics_in_actor_loss=critics_in_actor_loss,
    )
278 279
class TD3Workflow(OffPolicyWorkflowTemplate):
    """Off-policy training workflow for TD3.

    Wires together the training environment, a ``TD3Agent``, one optax
    optimizer (with separate optimizer states for actor and critic), an
    evaluator and a replay buffer, and implements a single training step.
    """

    @classmethod
    def name(cls):
        """Return the workflow name, ``"TD3"``."""
        return "TD3"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct the workflow and all of its components from ``config``.

        Args:
            config: Configuration holding the environment, network,
                optimizer, replay-buffer and algorithm settings read below.

        Returns:
            A workflow instance built from
            ``(env, agent, optimizer, evaluator, replay_buffer, config)``.
        """
        # Training env. ``ori_obs`` is recorded because TD3Agent.critic_loss
        # reads it as the next observation.
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
            policy_obs_key=config.agent_network.policy_obs_key,
            value_obs_key=config.agent_network.value_obs_key,
        )

        # one optimizer, two opt_states (in setup function) for both actor and critic
        # Gradient clipping is only chained in when a positive norm is configured.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        # The minimum sample threshold is never smaller than one batch.
        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=max(
                config.batch_size, config.learning_start_timesteps
            ),
            sample_batch_size=config.batch_size,
        )

        # Separate evaluation env with autoreset disabled.
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        # Evaluation uses the noise-free action function.
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(
            env,
            agent,
            optimizer,
            evaluator,
            replay_buffer,
            config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent state and the per-network optimizer states.

        Args:
            key: PRNG key passed to ``agent.init``.

        Returns:
            ``(agent_state, opt_state)`` where ``opt_state`` is a PyTreeDict
            with separate ``actor`` and ``critic`` optimizer states, both
            created by the same optimizer.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: rollout, buffer insert, updates.

        The iteration collects ``rollout_length`` steps from the training
        env, optionally updates the observation statistics, adds the data to
        the replay buffer, and then runs ``num_updates_per_iter`` update
        rounds. Each round performs ``actor_update_interval - 1`` critic-only
        updates followed by one critic update, one actor update and a soft
        update of both target networks.

        Args:
            state: Current workflow state.

        Returns:
            ``(train_metrics, new_state)``.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Keep the dones before the trajectory is cleaned; they are needed
        # below to count finished episodes.
        # NOTE(review): presumably clean_trajectory drops fields not needed
        # for training - confirm in offpolicy_utils.
        trajectory_dones = trajectory.dones
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        # Update running observation statistics only when normalization is on.
        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        # Loss wrappers return (scalar_loss, aux_dict), matching has_aux=True.
        def critic_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.critic_loss(agent_state, sample_batch, key)

            loss = loss_dict.critic_loss
            return loss, loss_dict

        def actor_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.actor_loss(agent_state, sample_batch, key)

            loss = loss_dict.actor_loss
            return loss, loss_dict

        # detach_fn selects the parameters handed to the update and attach_fn
        # writes the updated ones back, so each update function only changes
        # its own network (presumably these are the params being
        # differentiated - confirm in agent_gradient_update).
        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, critic_params: agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, actor_params: agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )

        def _sample_and_update_fn(carry, unused_t):
            # One update round; carry is (key, agent_state, opt_state).
            key, agent_state, opt_state = carry

            critic_opt_state = opt_state.critic
            actor_opt_state = opt_state.actor

            key, critic_key, actor_key, rb_key = jax.random.split(key, num=4)

            # Delayed policy update: extra critic-only updates come first.
            # This is a Python-level branch on a config value.
            if self.config.actor_update_interval - 1 > 0:

                def _sample_and_update_critic_fn(carry, unused_t):
                    key, agent_state, critic_opt_state = carry

                    key, rb_key, critic_key = jax.random.split(key, num=3)
                    # it's safe to use read-only replay_buffer_state here.
                    sample_batch = self.replay_buffer.sample(
                        replay_buffer_state, rb_key
                    )

                    # Losses of these critic-only updates are not reported.
                    (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                        critic_update_fn(
                            critic_opt_state, agent_state, sample_batch, critic_key
                        )
                    )

                    return (key, agent_state, critic_opt_state), None

                key, critic_multiple_update_key = jax.random.split(key)

                (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                    _sample_and_update_critic_fn,
                    (critic_multiple_update_key, agent_state, critic_opt_state),
                    (),
                    length=self.config.actor_update_interval - 1,
                )

            # Final critic update and the actor update share one batch.
            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(
                    critic_opt_state, agent_state, sample_batch, critic_key
                )
            )

            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(actor_opt_state, agent_state, sample_batch, actor_key)
            )

            # Target networks are only moved here, once per actor update.
            target_actor_params = soft_target_update(
                agent_state.params.target_actor_params,
                agent_state.params.actor_params,
                self.config.tau,
            )
            target_critic_params = soft_target_update(
                agent_state.params.target_critic_params,
                agent_state.params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    target_actor_params=target_actor_params,
                    target_critic_params=target_critic_params,
                )
            )

            opt_state = opt_state.replace(
                actor=actor_opt_state, critic=critic_opt_state
            )

            return (
                (key, agent_state, opt_state),
                (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
            )

        # Run num_updates_per_iter update rounds.
        # NOTE(review): scan_and_mean presumably averages the per-round
        # outputs over the scan axis - confirm in evorl.utils.jax_utils.
        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )

        train_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        # Episodes are counted from the dones saved before cleaning.
        sampled_epsiodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_epsiodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )