Source code for evorl.algorithms.contrib.pop_td3

  1import logging
  2from functools import partial
  3import math
  4from typing_extensions import Self  # pytype: disable=not-supported-yet]
  5from omegaconf import DictConfig
  6
  7import chex
  8import jax
  9import jax.numpy as jnp
 10import jax.tree_util as jtu
 11
 12from evorl.distributed import (
 13    agent_gradient_update,
 14    psum,
 15)
 16from evorl.agent import AgentState, RandomAgent
 17from evorl.types import PyTreeDict, State
 18from evorl.metrics import MetricBase, EvaluateMetric
 19from evorl.rollout import rollout
 20from evorl.sample_batch import SampleBatch
 21from evorl.utils import running_statistics
 22from evorl.utils.jax_utils import tree_stop_gradient, scan_and_mean
 23from evorl.utils.rl_toolkits import soft_target_update, flatten_rollout_trajectory
 24from evorl.recorders import add_prefix, get_1d_array_statistics, get_1d_array
 25
 26from evorl.algorithms.offpolicy_utils import clean_trajectory, skip_replay_buffer_state
 27from evorl.algorithms.td3 import TD3TrainMetric, TD3Workflow
 28
 29
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
 31
 32
class WorkflowMetric(MetricBase):
    """Counters accumulated by the population TD3 workflow.

    Every field defaults to a scalar ``uint32`` zero.
    """

    # Environment timesteps collected by the whole population.
    sampled_timesteps: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Environment timesteps collected by one agent (rollout_length * num_envs
    # is added per training step).
    sampled_timesteps_per_agent: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Sum of the `dones` flags seen in the training rollouts.
    sampled_episodes: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Number of training steps performed (incremented by one per `step`).
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
38 39
class PopTD3Workflow(TD3Workflow):
    """Independent TD3 agents trained with a shared replay buffer."""

    @classmethod
    def name(cls) -> str:
        """Return the name of this workflow, ``"PopTD3"``."""
        return "PopTD3"
46
[docs] 47 @classmethod 48 def build_from_config( 49 cls, 50 config: DictConfig, 51 enable_multi_devices: bool = False, 52 enable_jit: bool = True, 53 ) -> Self: 54 devices = jax.local_devices() 55 56 if enable_multi_devices or len(devices) > 1: 57 raise NotImplementedError("Multi-devices is not supported yet.") 58 59 return super().build_from_config(config, enable_multi_devices, enable_jit)
    def _setup_workflow_metrics(self) -> MetricBase:
        """Create a fresh `WorkflowMetric` with its default field values."""
        return WorkflowMetric()

[docs] 64 def setup(self, key: chex.PRNGKey) -> State: 65 key, agent_key, env_key, rb_key = jax.random.split(key, 4) 66 67 agent_state, opt_state = self._setup_agent_and_optimizer(agent_key) 68 workflow_metrics = self._setup_workflow_metrics() 69 70 # TODO: what about using shared init env_state? 71 env_state = jax.vmap(self.env.reset)( 72 jax.random.split(env_key, self.config.pop_size) 73 ) 74 replay_buffer_state = self._setup_replaybuffer(rb_key) 75 76 state = State( 77 key=key, 78 metrics=workflow_metrics, 79 agent_state=agent_state, 80 env_state=env_state, 81 opt_state=opt_state, 82 replay_buffer_state=replay_buffer_state, 83 ) 84 85 logger.info("Start replay buffer post-setup") 86 87 state = self._postsetup_replaybuffer(state) 88 89 logger.info("Complete replay buffer post-setup") 90 91 return state
92 93 def _setup_agent_and_optimizer( 94 self, key: chex.PRNGKey 95 ) -> tuple[AgentState, chex.ArrayTree]: 96 agent_state = jax.vmap(self.agent.init, in_axes=(None, None, 0))( 97 self.env.obs_space, 98 self.env.action_space, 99 jax.random.split(key, self.config.pop_size), 100 ) 101 102 def _opt_init(agent_state): 103 return PyTreeDict( 104 actor=self.optimizer.init(agent_state.params.actor_params), 105 critic=self.optimizer.init(agent_state.params.critic_params), 106 ) 107 108 opt_state = jax.vmap(_opt_init)(agent_state) 109 110 return agent_state, opt_state 111 112 def _postsetup_replaybuffer(self, state: State) -> State: 113 action_space = self.env.action_space 114 obs_space = self.env.obs_space 115 config = self.config 116 replay_buffer_state = state.replay_buffer_state 117 agent_state = state.agent_state 118 119 def _rollout(agent, agent_state, key, rollout_length): 120 env_key, rollout_key = jax.random.split(key) 121 122 env_state = self.env.reset(env_key) 123 124 trajectory, env_state = rollout( 125 env_fn=self.env.step, 126 action_fn=agent.compute_actions, 127 env_state=env_state, 128 agent_state=agent_state, 129 key=rollout_key, 130 rollout_length=rollout_length, 131 env_extra_fields=("ori_obs", "termination"), 132 ) 133 134 # [T, B, ...] -> [T*B, ...] 
135 trajectory = clean_trajectory(trajectory) 136 trajectory = flatten_rollout_trajectory(trajectory) 137 trajectory = tree_stop_gradient(trajectory) 138 139 return trajectory 140 141 def _update_obs_preprocessor(agent_state, trajectory): 142 if ( 143 agent_state.obs_preprocessor_state is not None 144 and len(trajectory.obs) > 0 145 ): 146 agent_state = agent_state.replace( 147 obs_preprocessor_state=running_statistics.update( 148 agent_state.obs_preprocessor_state, 149 trajectory.obs, 150 dp_axis_name=self.dp_axis_name, 151 ) 152 ) 153 return agent_state 154 155 # ==== fill random transitions ==== 156 157 key, random_rollout_key, rollout_key = jax.random.split(state.key, num=3) 158 random_agent = RandomAgent() 159 random_agent_state = random_agent.init( 160 obs_space, action_space, jax.random.PRNGKey(0) 161 ) 162 rollout_length = config.random_timesteps // config.num_envs 163 164 trajectory = _rollout( 165 random_agent, 166 random_agent_state, 167 key=random_rollout_key, 168 rollout_length=rollout_length, 169 ) 170 171 agent_state = jax.vmap(_update_obs_preprocessor, in_axes=(0, None))( 172 agent_state, trajectory 173 ) 174 175 replay_buffer_state = self.replay_buffer.add(replay_buffer_state, trajectory) 176 177 rollout_timesteps = rollout_length * config.num_envs 178 sampled_timesteps = psum( 179 jnp.uint32(rollout_timesteps), axis_name=self.dp_axis_name 180 ) 181 182 # ==== fill tansition state from init agents (diff from TD3) ==== 183 rollout_length = math.ceil( 184 (config.learning_start_timesteps - rollout_timesteps) 185 / (config.num_envs * config.pop_size) 186 ) 187 188 _vmap_rollout = jax.vmap( 189 partial(_rollout, self.agent, rollout_length=rollout_length) 190 ) 191 192 trajectory = _vmap_rollout( 193 agent_state, jax.random.split(rollout_key, config.pop_size) 194 ) 195 agent_state = jax.vmap(_update_obs_preprocessor)(agent_state, trajectory) 196 197 # [#pop, T*B] -> [#pop*T*B, ...] 
198 trajectory = flatten_rollout_trajectory(trajectory) 199 replay_buffer_state = self.replay_buffer.add(replay_buffer_state, trajectory) 200 201 rollout_timesteps = rollout_length * config.num_envs * config.pop_size 202 sampled_timesteps = sampled_timesteps + psum( 203 jnp.uint32(rollout_timesteps), axis_name=self.dp_axis_name 204 ) 205 206 workflow_metrics = state.metrics.replace( 207 sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps, 208 ).all_reduce(dp_axis_name=self.dp_axis_name) 209 210 return state.replace( 211 key=key, 212 metrics=workflow_metrics, 213 agent_state=agent_state, 214 replay_buffer_state=replay_buffer_state, 215 ) 216
    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration for the whole population.

        The iteration collects ``rollout_length`` steps with every agent,
        adds the transitions to the shared replay buffer, and then performs
        ``num_updates_per_iter`` update rounds. Each round runs
        ``actor_update_interval - 1`` critic-only updates followed by one
        critic update, one actor update and a soft update of the target
        networks.

        Args:
            state: Current workflow state.

        Returns:
            ``(train_metrics, new_state)`` where ``train_metrics`` is a
            `TD3TrainMetric` built from the losses returned by
            ``scan_and_mean`` and ``new_state`` carries the updated key,
            metrics, agent/env/optimizer states and replay buffer.
        """
        pop_size = self.config.pop_size
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        _rollout = partial(
            rollout,
            self.env.step,
            self.agent.compute_actions,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # the trajectory [#pop, T, B, ...]
        trajectory, env_state = jax.vmap(_rollout)(
            state.env_state, state.agent_state, jax.random.split(rollout_key, pop_size)
        )

        # Saved before the trajectory is cleaned and flattened; used below to
        # count finished episodes.
        trajectory_dones = trajectory.dones
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_pop_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        # NOTE(review): agent_state is created under vmap over the population,
        # but this update is not vmapped (unlike the updates in
        # `_postsetup_replaybuffer`) and receives the population-flattened
        # observations -- confirm `running_statistics.update` handles this
        # as intended.
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def critic_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.critic_loss(agent_state, sample_batch, key)

            loss = loss_dict.critic_loss
            return loss, loss_dict

        def actor_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.actor_loss(agent_state, sample_batch, key)

            loss = loss_dict.actor_loss
            return loss, loss_dict

        # attach_fn/detach_fn select which parameter subtree of the agent
        # state each update function reads and writes.
        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, critic_params: agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, actor_params: agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )

        # Arguments are (opt_state, agent_state, sample_batch, key): the
        # optimizer state, agent state and key are mapped over the population,
        # while the sampled batch is shared by all agents (in_axes=None).
        critic_update_fn = jax.vmap(critic_update_fn, in_axes=(0, 0, None, 0))
        actor_update_fn = jax.vmap(actor_update_fn, in_axes=(0, 0, None, 0))

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry

            critic_opt_state = opt_state.critic
            actor_opt_state = opt_state.actor

            key, critic_key, actor_key, rb_key = jax.random.split(key, num=4)

            # Extra critic-only updates, executed before the joint
            # critic + actor update below.
            if self.config.actor_update_interval - 1 > 0:

                def _sample_and_update_critic_fn(carry, unused_t):
                    key, agent_state, critic_opt_state = carry

                    key, rb_key, critic_key = jax.random.split(key, num=3)
                    # it's safe to use read-only replay_buffer_state here.
                    sample_batch = self.replay_buffer.sample(
                        replay_buffer_state, rb_key
                    )

                    critic_key = jax.random.split(critic_key, pop_size)

                    # The losses of these extra updates are not reported.
                    (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                        critic_update_fn(
                            critic_opt_state, agent_state, sample_batch, critic_key
                        )
                    )

                    return (key, agent_state, critic_opt_state), None

                key, critic_multiple_update_key = jax.random.split(key)

                (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                    _sample_and_update_critic_fn,
                    (critic_multiple_update_key, agent_state, critic_opt_state),
                    (),
                    length=self.config.actor_update_interval - 1,
                )

            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            # One key per population member.
            critic_key = jax.random.split(critic_key, pop_size)
            actor_key = jax.random.split(actor_key, pop_size)

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(
                    critic_opt_state, agent_state, sample_batch, critic_key
                )
            )

            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(actor_opt_state, agent_state, sample_batch, actor_key)
            )

            # Move the target networks towards the freshly updated networks.
            target_actor_params = soft_target_update(
                agent_state.params.target_actor_params,
                agent_state.params.actor_params,
                self.config.tau,
            )
            target_critic_params = soft_target_update(
                agent_state.params.target_critic_params,
                agent_state.params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    target_actor_params=target_actor_params,
                    target_critic_params=target_critic_params,
                )
            )

            opt_state = opt_state.replace(
                actor=actor_opt_state, critic=critic_opt_state
            )

            return (
                (key, agent_state, opt_state),
                (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
            )

        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )

        # [#pop, ...]
        train_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timesteps collected in this step
        sampled_timesteps_per_agent = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_timesteps = psum(
            jnp.uint32(
                self.config.rollout_length * self.config.num_envs * self.config.pop_size
            ),
            axis_name=self.dp_axis_name,
        )
        sampled_epsiodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_timesteps_per_agent=state.metrics.sampled_timesteps_per_agent
            + sampled_timesteps_per_agent,
            sampled_episodes=state.metrics.sampled_episodes + sampled_epsiodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )
423
[docs] 424 def evaluate(self, state: State) -> tuple[MetricBase, State]: 425 key, eval_key = jax.random.split(state.key, num=2) 426 427 # [#pop, #episodes] 428 raw_eval_metrics = jax.vmap( 429 partial(self.evaluator.evaluate, num_episodes=self.config.eval_episodes), 430 )( 431 state.agent_state, 432 jax.random.split(eval_key, self.config.pop_size), 433 ) 434 435 eval_metrics = EvaluateMetric( 436 episode_returns=raw_eval_metrics.episode_returns.mean(-1), 437 episode_lengths=raw_eval_metrics.episode_lengths.mean(-1), 438 ).all_reduce(dp_axis_name=self.dp_axis_name) 439 440 state = state.replace(key=key) 441 return eval_metrics, state
442
[docs] 443 def learn(self, state: State) -> State: 444 num_devices = jax.device_count() 445 one_step_timesteps = ( 446 self.config.rollout_length * self.config.num_envs * self.config.pop_size 447 ) 448 sampled_timesteps = state.metrics.sampled_timesteps.tolist() 449 num_iters = math.ceil( 450 (self.config.total_timesteps - sampled_timesteps) 451 / (one_step_timesteps * self.config.fold_iters * num_devices) 452 ) 453 start_iteration = state.metrics.iterations.tolist() 454 final_iteration = num_iters + start_iteration 455 456 for i in range(num_iters): 457 train_metrics, state = self._multi_steps(state) 458 workflow_metrics = state.metrics 459 460 # current iteration 461 iterations = state.metrics.iterations.tolist() 462 self.recorder.write(workflow_metrics.to_local_dict(), iterations) 463 464 train_metrics_dict = jtu.tree_map( 465 partial(get_1d_array_statistics, histogram=True), 466 train_metrics.to_local_dict(), 467 ) 468 469 self.recorder.write(train_metrics_dict, iterations) 470 471 if ( 472 iterations % self.config.eval_interval == 0 473 or iterations == final_iteration 474 ): 475 eval_metrics, state = self.evaluate(state) 476 477 eval_metrics_dict = jtu.tree_map( 478 get_1d_array, 479 eval_metrics.to_local_dict(), 480 ) 481 482 self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iterations) 483 484 saved_state = state 485 if not self.config.save_replay_buffer: 486 saved_state = skip_replay_buffer_state(saved_state) 487 self.checkpoint_manager.save( 488 iterations, 489 saved_state, 490 force=iterations == final_iteration, 491 ) 492 493 return state
494 495
def flatten_pop_rollout_trajectory(trajectory: SampleBatch) -> SampleBatch:
    """Merge the three leading axes of every leaf.

    Each leaf goes from shape [#pop, T, B, ...] to [#pop*T*B, ...].
    """

    def _merge_leading_axes(leaf):
        # Collapse dimensions 0, 1 and 2 into a single leading dimension.
        return jax.lax.collapse(leaf, 0, 3)

    return jtu.tree_map(_merge_leading_axes, trajectory)