Source code for evorl.algorithms.contrib.td3_v2

  1import logging
  2import math
  3
  4import jax
  5import jax.numpy as jnp
  6import jax.tree_util as jtu
  7
  8from evorl.distributed import psum
  9from evorl.distributed.gradients import agent_gradient_update
 10from evorl.metrics import MetricBase
 11from evorl.rollout import rollout
 12from evorl.types import (
 13    PyTreeDict,
 14    State,
 15)
 16from evorl.utils import running_statistics
 17from evorl.utils.jax_utils import tree_stop_gradient
 18from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
 19from evorl.recorders import add_prefix
 20
 21from evorl.algorithms.offpolicy_utils import clean_trajectory, skip_replay_buffer_state
 22from evorl.algorithms.td3 import TD3TrainMetric, TD3Workflow
 23
 24logger = logging.getLogger(__name__)
 25
 26MISSING_LOSS = -1e10
 27
 28
class TD3V2Workflow(TD3Workflow):
    """TD3 workflow implemented similarly to the SB3 and CleanRL versions.

    Every call to :meth:`step` collects one rollout, stores it in the replay
    buffer and updates the critic once on a sampled batch. The actor and the
    two target networks are updated on that same batch only when the new
    iteration count is a multiple of ``config.actor_update_interval``.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "TD3-V2"

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: collect a rollout, then update the agent.

        Args:
            state: Current workflow state. The fields read here are ``key``,
                ``env_state``, ``agent_state``, ``opt_state``,
                ``replay_buffer_state`` and ``metrics``.

        Returns:
            A tuple ``(train_metrics, new_state)``. ``train_metrics`` holds the
            critic loss, the actor loss (``MISSING_LOSS`` when the actor update
            was skipped on this iteration) and the merged raw loss dicts.
        """
        # Iteration index after this step; it gates the delayed actor update
        # and is also what gets stored in the workflow metrics below.
        next_iteration = state.metrics.iterations + 1
        key, rollout_key, rb_key, critic_key, actor_key = jax.random.split(
            state.key, num=5
        )

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        opt_state = state.opt_state

        # Refresh the observation statistics only when a preprocessor state
        # exists on the agent.
        obs_stats = agent_state.obs_preprocessor_state
        if obs_stats is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    obs_stats,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def critic_loss_fn(agent_state, sample_batch, key):
            # Return (scalar loss, aux dict), the shape used with has_aux=True.
            loss_dict = self.agent.critic_loss(agent_state, sample_batch, key)
            return loss_dict.critic_loss, loss_dict

        def actor_loss_fn(agent_state, sample_batch, key):
            # Return (scalar loss, aux dict), the shape used with has_aux=True.
            loss_dict = self.agent.actor_loss(agent_state, sample_batch, key)
            return loss_dict.actor_loss, loss_dict

        # detach_fn selects the parameter subtree handed to the update helper;
        # attach_fn writes a parameter subtree back into the agent state.
        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, critic_params: agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, actor_params: agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )

        def _update_critic_fn(agent_state, opt_state, sample_batch, key):
            # One critic update; only the critic slot of opt_state changes.
            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(opt_state.critic, agent_state, sample_batch, key)
            )
            opt_state = opt_state.replace(critic=critic_opt_state)

            return critic_loss, critic_loss_dict, agent_state, opt_state

        def _update_actor_fn(agent_state, opt_state, sample_batch, key):
            # One actor update followed by the soft update of BOTH target
            # networks, so the targets only move on actor-update iterations.
            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(opt_state.actor, agent_state, sample_batch, key)
            )

            params = agent_state.params
            target_actor_params = soft_target_update(
                params.target_actor_params,
                params.actor_params,
                self.config.tau,
            )
            target_critic_params = soft_target_update(
                params.target_critic_params,
                params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=params.replace(
                    target_actor_params=target_actor_params,
                    target_critic_params=target_critic_params,
                )
            )
            opt_state = opt_state.replace(actor=actor_opt_state)

            return actor_loss, actor_loss_dict, agent_state, opt_state

        def _dummy_update_actor_fn(agent_state, opt_state, sample_batch, key):
            # Skipped actor update: states pass through untouched and the loss
            # is the MISSING_LOSS sentinel.
            # NOTE(review): lax.cond needs both branches to return matching
            # pytree structures, so this presumably relies on
            # `agent.actor_loss` returning only an `actor_loss` entry -- confirm.
            actor_loss = jnp.full((), fill_value=MISSING_LOSS)
            actor_loss_dict = PyTreeDict(actor_loss=actor_loss)

            return actor_loss, actor_loss_dict, agent_state, opt_state

        sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

        critic_loss, critic_loss_dict, agent_state, opt_state = _update_critic_fn(
            agent_state, opt_state, sample_batch, critic_key
        )

        # Note: using cond prohibits the parallel training by vmap
        actor_loss, actor_loss_dict, agent_state, opt_state = jax.lax.cond(
            next_iteration % self.config.actor_update_interval == 0,
            _update_actor_fn,
            _dummy_update_actor_fn,
            agent_state,
            opt_state,
            sample_batch,
            actor_key,
        )

        train_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            iterations=next_iteration,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run the training loop until ``config.total_timesteps`` is reached.

        The number of loop iterations is the remaining timesteps divided by
        ``rollout_length * num_envs * fold_iters * jax.device_count()``,
        rounded up. Each iteration calls ``self._multi_steps``, records the
        train and workflow metrics, evaluates when the iteration count is a
        multiple of ``config.eval_interval`` (or on the final iteration), and
        hands the state to the checkpoint manager.

        Args:
            state: Workflow state to start (or resume) training from.

        Returns:
            The workflow state after the last iteration.
        """
        num_devices = jax.device_count()
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        sampled_timesteps = state.metrics.sampled_timesteps.tolist()
        num_iters = math.ceil(
            (self.config.total_timesteps - sampled_timesteps)
            / (one_step_timesteps * self.config.fold_iters * num_devices)
        )

        start_iteration = state.metrics.iterations.tolist()
        final_iteration = num_iters + start_iteration

        for _ in range(num_iters):
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            # current iteration
            iterations = workflow_metrics.iterations.tolist()
            is_final = iterations == final_iteration

            # Replace the skipped-actor-update sentinel with None before
            # handing the metrics to the recorder.
            train_metrics = jtu.tree_map(
                lambda x: None if x == MISSING_LOSS else x, train_metrics
            )

            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            if iterations % self.config.eval_interval == 0 or is_final:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            # NOTE(review): checkpointing is placed inside the loop, as implied
            # by the per-iteration `force` flag -- confirm against upstream.
            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(iterations, saved_state, force=is_final)

        return state