Source code for evorl.algorithms.erl.erl_td3.erl_td3_workflow

  1import logging
  2from omegaconf import DictConfig
  3import math
  4
  5import chex
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9import optax
 10
 11from evorl.agent import AgentState, Agent
 12from evorl.distributed import agent_gradient_update
 13from evorl.types import PyTreeDict, State
 14from evorl.utils.jax_utils import (
 15    right_shift_with_padding,
 16    tree_stop_gradient,
 17    scan_and_last,
 18    scan_and_mean,
 19)
 20from evorl.utils.rl_toolkits import (
 21    flatten_rollout_trajectory,
 22    flatten_pop_rollout_episode,
 23    soft_target_update,
 24)
 25from evorl.recorders import get_1d_array_statistics
 26from evorl.algorithms.td3 import TD3TrainMetric, TD3NetworkParams
 27
 28from ..erl_workflow import ERLWorkflowBase, ERLTrainMetric
 29
 30logger = logging.getLogger(__name__)
 31
 32
class ERLTD3WorkflowTemplate(ERLWorkflowBase):
    """A template for ERL workflow on TD3 Agent.

    Provides the pieces shared by ERL-TD3 workflows: an optional EC-only
    warmup phase run from :meth:`setup`, episode rollouts for the EC
    population and for the RL agents, and the RL (TD3) update loop.
    """

    # Note: Turn off the warmup logging in PBT or parallel training
    LOGGING_WARMUP_FLAG = True

    def __init__(self, **kwargs):
        """Initialize the base workflow and build the RL update function."""
        super().__init__(**kwargs)
        self._rl_update_fn = build_erl_rl_update_fn(
            self.agent,
            self.optimizer,
            self.config,
            self.agent_state_vmap_axes,
        )

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial state and run the optional warmup phase.

        When ``config.warmup_iters > 0``, :meth:`warmup_step` is executed
        ``config.warmup_iters`` times, split into folds of
        ``config.eval_interval`` iterations plus one shorter remainder fold.
        After each fold, the train metrics returned by ``scan_and_last`` are
        written to the recorder if ``LOGGING_WARMUP_FLAG`` is set.

        Args:
            key: PRNG key forwarded to the base class ``setup``.

        Returns:
            The workflow state after setup (and warmup, if enabled).
        """
        state = super().setup(key)

        if self.config.warmup_iters > 0:
            logger.info("Start warmup")

            def _warmup_step(state, unused_t):
                # Adapt warmup_step to the (carry, x) -> (carry, y) scan signature.
                train_metrics, state = self.warmup_step(state)
                return state, train_metrics

            def _logging(train_metrics, iters):
                if not self.LOGGING_WARMUP_FLAG:
                    return

                train_metrics_dict = train_metrics.to_local_dict()
                for name in ("pop_episode_returns", "pop_episode_lengths"):
                    train_metrics_dict[name] = get_1d_array_statistics(
                        train_metrics_dict[name], histogram=True
                    )

                # warmup_step never fills the RL entries, so drop them
                # before writing the record.
                del train_metrics_dict["rl_episode_lengths"]
                del train_metrics_dict["rl_episode_returns"]
                del train_metrics_dict["rl_metrics"]

                self.recorder.write(train_metrics_dict, iters)

            num_fold_iters = math.floor(
                self.config.warmup_iters / self.config.eval_interval
            )
            last_fold_iters = self.config.warmup_iters % self.config.eval_interval

            for _ in range(num_fold_iters):
                state, train_metrics = scan_and_last(
                    _warmup_step, state, (), length=self.config.eval_interval
                )
                _logging(train_metrics, state.metrics.iterations)

            if last_fold_iters > 0:
                state, train_metrics = scan_and_last(
                    _warmup_step, state, (), length=last_fold_iters
                )
                _logging(train_metrics, state.metrics.iterations)

            logger.info("Complete warmup")

        return state

    def warmup_step(self, state: State) -> tuple[ERLTrainMetric, State]:
        """Run one EC-only iteration (ask, evaluate, tell).

        The population is rolled out through :meth:`_ec_rollout`, which also
        adds the collected data to the replay buffer. ``state.agent_state``
        and the RL optimizer state are not modified by this step.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, state)`` with the population metrics of
            this iteration and the state carrying the new PRNG key, workflow
            counters, replay buffer state and EC optimizer state.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state

        key, ec_rollout_key = jax.random.split(state.key, 2)

        # 1. ask()
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)

        # 2. evaluate()
        pop_agent_state = erl_replace_td3_actor_params(agent_state, pop_actor_params)
        ec_eval_metrics, _, replay_buffer_state = self._ec_rollout(
            pop_agent_state, replay_buffer_state, ec_rollout_key
        )
        # Fitness of an individual is its mean return over the last axis.
        fitnesses = ec_eval_metrics.episode_returns.mean(axis=-1)

        # 3. tell()
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=fitnesses,
            ec_info=ec_metrics,
            rb_size=replay_buffer_state.buffer_size,
        )

        # calculate the number of timestep
        sampled_timesteps = ec_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
        )

        return train_metrics, state

    def _ec_rollout(self, agent_state, replay_buffer_state, key):
        """Roll out the EC population and add its data to the replay buffer.

        Uses ``self.ec_collector`` with ``config.pop_size`` agents and
        ``config.episodes_for_fitness`` episodes; see ``rollout_episode``
        for the returned tuple.
        """
        return rollout_episode(
            agent_state,
            replay_buffer_state,
            key,
            collector=self.ec_collector,
            replay_buffer=self.replay_buffer,
            agent_state_vmap_axes=self.agent_state_vmap_axes,
            num_agents=self.config.pop_size,
            num_episodes=self.config.episodes_for_fitness,
        )

    def _rl_rollout(self, agent_state, replay_buffer_state, key):
        """Roll out the RL agents and add their data to the replay buffer.

        Uses ``self.rl_collector`` with ``config.num_rl_agents`` agents and
        ``config.rollout_episodes`` episodes; see ``rollout_episode`` for
        the returned tuple.
        """
        return rollout_episode(
            agent_state,
            replay_buffer_state,
            key,
            collector=self.rl_collector,
            replay_buffer=self.replay_buffer,
            agent_state_vmap_axes=self.agent_state_vmap_axes,
            num_agents=self.config.num_rl_agents,
            num_episodes=self.config.rollout_episodes,
        )

    def _rl_update(self, agent_state, opt_state, replay_buffer_state, key):
        """Run the RL update loop for one iteration.

        Repeats ``config.num_rl_updates_per_iter`` times: sample
        ``actor_update_interval * num_rl_agents`` batches from the replay
        buffer and pass them to the update function built in ``__init__``.

        Args:
            agent_state: Agent state of the RL agents.
            opt_state: Optimizer state consumed by the update function.
            replay_buffer_state: Replay buffer state to sample from; it is
                only read here.
            key: PRNG key for sampling and learning.

        Returns:
            A tuple ``(td3_metrics, agent_state, opt_state)``. The metrics
            are the outputs of ``scan_and_mean`` over the update steps
            (presumably their mean over steps -- confirm against
            ``scan_and_mean``).
        """

        def _sample_fn(key):
            return self.replay_buffer.sample(replay_buffer_state, key)

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry

            key, rb_key, learn_key = jax.random.split(key, 3)

            # One independent batch per (update step, RL agent) pair.
            rb_keys = jax.random.split(
                rb_key, self.config.actor_update_interval * self.config.num_rl_agents
            )
            sample_batches = jax.vmap(_sample_fn)(rb_keys)

            # (actor_update_interval, num_rl_agents, B, ...)
            sample_batches = jtu.tree_map(
                lambda x: x.reshape(
                    (
                        self.config.actor_update_interval,
                        self.config.num_rl_agents,
                        *x.shape[1:],
                    )
                ),
                sample_batches,
            )

            (agent_state, opt_state), train_info = self._rl_update_fn(
                agent_state, opt_state, sample_batches, learn_key
            )

            return (key, agent_state, opt_state), train_info

        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (key, agent_state, opt_state),
            (),
            length=self.config.num_rl_updates_per_iter,
        )

        # smoothed td3 metrics
        td3_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        )

        return td3_metrics, agent_state, opt_state
219 220
def erl_replace_td3_actor_params(
    agent_state: AgentState, pop_actor_params: TD3NetworkParams
) -> AgentState:
    """Return ``agent_state`` with its params swapped for actor-only params.

    Both the actor and the target actor are set to ``pop_actor_params``;
    the critic and target critic entries are set to ``None``.

    Args:
        agent_state: Agent state whose ``params`` field is replaced.
        pop_actor_params: Actor parameters to install.

    Returns:
        A copy of ``agent_state`` carrying the new ``TD3NetworkParams``.
    """
    actor_only_params = TD3NetworkParams(
        actor_params=pop_actor_params,
        target_actor_params=pop_actor_params,
        critic_params=None,
        target_critic_params=None,
    )
    return agent_state.replace(params=actor_only_params)
# Placeholder TD3 metrics in which every entry is a zero scalar array.
# Used as the template that ``create_dummy_td3_trainmetric`` broadcasts.
DUMMY_TD3_TRAINMETRIC = TD3TrainMetric(
    critic_loss=jnp.zeros(()),
    actor_loss=jnp.zeros(()),
    raw_loss_dict=PyTreeDict(
        critic_loss=jnp.zeros(()),
        q_value=jnp.zeros(()),
        actor_loss=jnp.zeros(()),
    ),
)
def create_dummy_td3_trainmetric(num: int) -> TD3TrainMetric:
    """Build placeholder TD3 metrics whose raw losses have a leading axis.

    Only the leaves of ``raw_loss_dict`` are broadcast to shape
    ``(num, ...)``; ``critic_loss`` and ``actor_loss`` are kept exactly as
    in ``DUMMY_TD3_TRAINMETRIC``.

    Args:
        num: Size of the leading axis added to each raw-loss leaf.

    Returns:
        A ``TD3TrainMetric`` derived from ``DUMMY_TD3_TRAINMETRIC``.

    Raises:
        ValueError: If ``num`` is not at least 1.
    """
    if not num >= 1:
        raise ValueError(f"num should be positive, got {num}")

    def _add_leading_axis(leaf):
        return jnp.broadcast_to(leaf, (num, *leaf.shape))

    broadcast_raw_losses = jtu.tree_map(
        _add_leading_axis, DUMMY_TD3_TRAINMETRIC.raw_loss_dict
    )
    return DUMMY_TD3_TRAINMETRIC.replace(raw_loss_dict=broadcast_raw_losses)
255 256
def rollout_episode(
    agent_state: AgentState,
    replay_buffer_state,
    key,
    *,
    collector,
    replay_buffer,
    agent_state_vmap_axes,
    num_episodes,
    num_agents,
):
    """Roll out episodes for several agents and store them in the buffer.

    ``collector.rollout`` is vmapped over the agents, with one PRNG key per
    agent. The resulting trajectory has ``next_obs`` and ``dones`` cleared,
    is flattened, and is added to the replay buffer with a validity mask.

    Args:
        agent_state: Agent state, batched according to
            ``agent_state_vmap_axes``.
        replay_buffer_state: Replay buffer state to add the data to.
        key: PRNG key, split into ``num_agents`` rollout keys.
        collector: Object providing ``rollout(agent_state, key, num_episodes)``.
        replay_buffer: Object providing ``add(state, trajectory, mask)``.
        agent_state_vmap_axes: ``in_axes`` spec for ``agent_state``.
        num_episodes: Passed unbatched to every rollout.
        num_agents: Number of agents, i.e. the size of the vmapped axis.

    Returns:
        A tuple ``(eval_metrics, trajectory, replay_buffer_state)`` where
        ``trajectory`` is the flattened, gradient-stopped trajectory that
        was added to the buffer.
    """
    per_agent_keys = jax.random.split(key, num_agents)
    batched_rollout = jax.vmap(
        collector.rollout,
        in_axes=(agent_state_vmap_axes, 0, None),
    )
    eval_metrics, trajectory = batched_rollout(
        agent_state, per_agent_keys, num_episodes
    )

    # [n, T, B, ...] -> [T, n*B, ...]
    trajectory = flatten_pop_rollout_episode(trajectory.replace(next_obs=None))

    # The mask is built from the dones shifted by one step; presumably it
    # keeps the terminal step itself and drops the steps after it -- confirm
    # against right_shift_with_padding.
    shifted_dones = right_shift_with_padding(trajectory.dones, 1)
    mask = jnp.logical_not(shifted_dones)

    flattened = flatten_rollout_trajectory((trajectory.replace(dones=None), mask))
    trajectory, mask = tree_stop_gradient(flattened)

    new_replay_buffer_state = replay_buffer.add(replay_buffer_state, trajectory, mask)

    return eval_metrics, trajectory, new_replay_buffer_state
289 290
def build_erl_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
    agent_state_vmap_axes: AgentState,
):
    """Build the TD3 update function for K (actor, critic) pairs.

    The returned function performs ``config.actor_update_interval`` critic
    updates (one per leading slice of ``sample_batches``), then one actor
    update on the last slice, then soft-updates both target networks with
    ``config.tau``.

    Args:
        agent: Agent providing ``critic_loss`` and ``actor_loss``.
        optimizer: Optimizer handed to ``agent_gradient_update`` for both
            the critic and the actor updates.
        config: Config providing ``num_rl_agents``,
            ``actor_update_interval`` and ``tau``.
        agent_state_vmap_axes: ``in_axes`` spec used to vmap the losses
            over the agent state.

    Returns:
        ``_update_fn(agent_state, opt_state, sample_batches, key)`` which
        returns ``((agent_state, opt_state), (critic_loss, actor_loss,
        critic_loss_dict, actor_loss_dict))``.
    """
    num_rl_agents = config.num_rl_agents
    loss_in_axes = (agent_state_vmap_axes, 0, 0)

    def critic_loss_fn(agent_state, sample_batch, key):
        # Sum of the critic losses, vmapped over the agents.
        # sample_batch: (n, B, ...)
        agent_keys = jax.random.split(key, num_rl_agents)
        loss_dict = jax.vmap(agent.critic_loss, in_axes=loss_in_axes)(
            agent_state, sample_batch, agent_keys
        )
        return loss_dict.critic_loss.sum(), loss_dict

    def actor_loss_fn(agent_state, sample_batch, key):
        # Sum of the actor losses, vmapped over the agents.
        agent_keys = jax.random.split(key, num_rl_agents)
        loss_dict = jax.vmap(agent.actor_loss, in_axes=loss_in_axes)(
            agent_state, sample_batch, agent_keys
        )
        return loss_dict.actor_loss.sum(), loss_dict

    def _attach_critic_params(agent_state, critic_params):
        new_params = agent_state.params.replace(critic_params=critic_params)
        return agent_state.replace(params=new_params)

    def _detach_critic_params(agent_state):
        return agent_state.params.critic_params

    def _attach_actor_params(agent_state, actor_params):
        new_params = agent_state.params.replace(actor_params=actor_params)
        return agent_state.replace(params=new_params)

    def _detach_actor_params(agent_state):
        return agent_state.params.actor_params

    critic_update_fn = agent_gradient_update(
        critic_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_attach_critic_params,
        detach_fn=_detach_critic_params,
    )

    actor_update_fn = agent_gradient_update(
        actor_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_attach_actor_params,
        detach_fn=_detach_actor_params,
    )

    def _critic_only_step(carry, sample_batch):
        # Scan body: one critic update; its loss outputs are discarded.
        step_key, agent_state, critic_opt_state = carry

        step_key, update_key = jax.random.split(step_key)

        _, agent_state, critic_opt_state = critic_update_fn(
            critic_opt_state, agent_state, sample_batch, update_key
        )

        return (step_key, agent_state, critic_opt_state), None

    def _update_fn(agent_state, opt_state, sample_batches, key):
        critic_opt_state = opt_state.critic
        actor_opt_state = opt_state.actor

        key, critic_key, actor_key = jax.random.split(key, num=3)

        # All but the last slice feed critic-only updates; the last slice
        # is shared by the final critic update and the actor update.
        critic_sample_batches = jtu.tree_map(lambda x: x[:-1], sample_batches)
        last_sample_batch = jtu.tree_map(lambda x: x[-1], sample_batches)

        if config.actor_update_interval - 1 > 0:
            _, critic_scan_key = jax.random.split(key)

            scan_carry, _ = jax.lax.scan(
                _critic_only_step,
                (critic_scan_key, agent_state, critic_opt_state),
                critic_sample_batches,
                length=config.actor_update_interval - 1,
            )
            _, agent_state, critic_opt_state = scan_carry

        critic_out, agent_state, critic_opt_state = critic_update_fn(
            critic_opt_state, agent_state, last_sample_batch, critic_key
        )
        critic_loss, critic_loss_dict = critic_out

        actor_out, agent_state, actor_opt_state = actor_update_fn(
            actor_opt_state, agent_state, last_sample_batch, actor_key
        )
        actor_loss, actor_loss_dict = actor_out

        # not need vmap
        params = agent_state.params
        new_target_actor_params = soft_target_update(
            params.target_actor_params,
            params.actor_params,
            config.tau,
        )
        new_target_critic_params = soft_target_update(
            params.target_critic_params,
            params.critic_params,
            config.tau,
        )
        agent_state = agent_state.replace(
            params=params.replace(
                target_actor_params=new_target_actor_params,
                target_critic_params=new_target_critic_params,
            )
        )

        opt_state = opt_state.replace(actor=actor_opt_state, critic=critic_opt_state)

        return (
            (agent_state, opt_state),
            (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
        )

    return _update_fn