Source code for evorl.algorithms.contrib.td3_onpolicy

  1import logging
  2from functools import partial
  3import math
  4
  5import chex
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9import optax
 10from omegaconf import DictConfig
 11
 12from evorl.distributed import psum
 13from evorl.distributed.gradients import agent_gradient_update
 14from evorl.envs import AutoresetMode, create_env
 15from evorl.evaluators import Evaluator
 16from evorl.metrics import MetricBase
 17from evorl.rollout import rollout
 18from evorl.types import (
 19    PyTreeDict,
 20    State,
 21)
 22from evorl.utils import running_statistics
 23from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, scan_and_last
 24from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
 25from evorl.agent import Agent, AgentState
 26from evorl.recorders import add_prefix
 27from evorl.workflows import OnPolicyWorkflow
 28
 29from evorl.algorithms.td3 import make_mlp_td3_agent, clean_trajectory, TD3TrainMetric
 30
 31logger = logging.getLogger(__name__)
 32
 33MISSING_LOSS = -1e10
 34
 35
class TD3OnPolicyWorkflow(OnPolicyWorkflow):
    """TD3 workflow that trains directly on freshly collected rollouts.

    Every call to :meth:`step` collects one rollout from the training env and
    then performs ``reuse_rollout_epochs`` passes of shuffled minibatch updates
    over that rollout; no replay buffer is used in this class.
    """

    @classmethod
    def name(cls) -> str:
        """Return the display name of this workflow."""
        return "TD3-OnPolicy"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Convert global batch-like settings into per-device values in place.

        ``num_envs``, ``num_eval_envs`` and ``minibatch_size`` are each floor
        divided by ``jax.device_count()``. A warning is logged for every value
        that is not an exact multiple of the device count.

        Args:
            config: Workflow config; mutated in place.
        """
        num_devices = jax.device_count()

        if config.num_envs % num_devices != 0:
            logger.warning(
                f"num_envs({config.num_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_envs to {config.num_envs // num_devices}"
            )
        if config.num_eval_envs % num_devices != 0:
            logger.warning(
                f"num_eval_envs({config.num_eval_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_eval_envs to {config.num_eval_envs // num_devices}"
            )
        if config.minibatch_size % num_devices != 0:
            logger.warning(
                f"minibatch_size({config.minibatch_size}) cannot be divided by num_devices({num_devices}), "
                f"rescale minibatch_size to {config.minibatch_size // num_devices}"
            )

        config.num_envs = config.num_envs // num_devices
        config.num_eval_envs = config.num_eval_envs // num_devices
        config.minibatch_size = config.minibatch_size // num_devices

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct envs, agent, optimizer and evaluator from ``config``.

        Args:
            config: Workflow config providing the env, network, optimizer and
                rollout settings read below.

        Returns:
            A workflow instance built as ``cls(env, agent, optimizer,
            evaluator, config)``.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is only enabled for a positive, non-null norm.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        one_step_rollout_steps = config.num_envs * config.rollout_length
        # NOTE(review): `step` consumes the rollout in chunks of
        # minibatch_size * actor_update_interval; this check only looks at
        # minibatch_size. Left as-is pending confirmation of intent.
        if one_step_rollout_steps % config.minibatch_size != 0:
            # Fixed: the message previously read the misspelled key
            # `config.minibath_size`, which does not exist in the config.
            logger.warning(
                f"minibatch_size ({config.minibatch_size}) cannot divide "
                f"num_envs*rollout_length ({one_step_rollout_steps})"
            )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, config)

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialise the agent state and per-network optimizer states.

        Args:
            key: PRNG key passed to ``self.agent.init``.

        Returns:
            ``(agent_state, opt_state)`` where ``opt_state`` is a
            ``PyTreeDict`` with ``actor`` and ``critic`` entries, both created
            from the same ``self.optimizer``.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Collect one rollout and run the configured update epochs on it.

        Args:
            state: Current workflow state (PRNG key, env state, agent state,
                optimizer state and metrics are read from it).

        Returns:
            ``(train_metrics, new_state)``: the all-reduced ``TD3TrainMetric``
            for this step and the state with key, metrics, agent, env and
            optimizer state replaced.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # trajectory: [T, #envs, ...]
        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        agent_state = state.agent_state
        # Observation statistics are only tracked when a preprocessor state
        # exists on the agent state.
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)
        # ============================

        update_fn = build_rl_update_fn(self.agent, self.optimizer, self.config)

        # One "minibatch" here is a group of `actor_update_interval`
        # sub-batches of size `minibatch_size`.
        # NOTE(review): this is 0 when rollout_length * num_envs is smaller
        # than minibatch_size * actor_update_interval — confirm configs
        # guarantee otherwise.
        num_minibatches = (
            self.config.rollout_length
            * self.config.num_envs
            // (self.config.minibatch_size * self.config.actor_update_interval)
        )

        def _get_shuffled_minibatch(perm_key, x):
            # Shuffle along the leading axis, drop the remainder that does not
            # fill a whole group, then split into
            # [num_minibatches, actor_update_interval, minibatch_size, ...].
            # The same perm_key is passed for every leaf of the trajectory.
            x = jax.random.permutation(perm_key, x)[
                : num_minibatches
                * self.config.minibatch_size
                * self.config.actor_update_interval
            ]
            return x.reshape(
                num_minibatches,
                self.config.actor_update_interval,
                self.config.minibatch_size,
                *x.shape[1:],
            )

        def minibatch_step(carry, trajectory):
            # trajectory: [actor_update_interval, B, ...]

            opt_state, agent_state, key = carry
            key, learn_key = jax.random.split(key)

            (agent_state, opt_state), train_info = update_fn(
                agent_state, opt_state, trajectory, learn_key
            )

            return (opt_state, agent_state, key), train_info

        def epoch_step(carry, _):
            opt_state, agent_state, key = carry
            perm_key, learn_key = jax.random.split(key, num=2)

            # `trajectory` is the flattened rollout from the enclosing scope;
            # it is reshuffled with a fresh perm_key each epoch.
            (opt_state, agent_state, key), train_info = scan_and_mean(
                minibatch_step,
                (opt_state, agent_state, learn_key),
                jtu.tree_map(partial(_get_shuffled_minibatch, perm_key), trajectory),
                length=num_minibatches,
            )

            return (opt_state, agent_state, key), train_info

        # Presumably scan_and_mean averages over minibatches and scan_and_last
        # keeps the final epoch's output — verify against evorl.utils.jax_utils.
        (
            (opt_state, agent_state, _),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_last(
            epoch_step,
            (state.opt_state, agent_state, learn_key),
            None,
            length=self.config.reuse_rollout_epochs,
        )

        # ======== update metrics ========

        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        train_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run the training loop with periodic evaluation and checkpointing.

        The number of iterations is
        ``ceil(total_timesteps / (rollout_length * num_envs))``. After every
        iteration the workflow and training metrics are written to the
        recorder and the checkpoint manager is invoked; evaluation runs when
        the iteration count is a multiple of ``eval_interval`` or on the final
        iteration.

        Args:
            state: Workflow state to start (or resume) training from.

        Returns:
            The state after the last iteration.
        """
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        num_iters = math.ceil(self.config.total_timesteps / one_step_timesteps)

        # Offset by the iterations already recorded in `state` so the final
        # iteration is detected correctly when resuming.
        start_iteration = state.metrics.iterations.tolist()
        final_iteration = num_iters + start_iteration

        for _ in range(num_iters):
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            iterations = state.metrics.iterations.tolist()

            self.recorder.write(workflow_metrics.to_local_dict(), iterations)
            self.recorder.write(train_metrics.to_local_dict(), iterations)

            if (
                iterations % self.config.eval_interval == 0
                or iterations == final_iteration
            ):
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            # Saving is forced on the last iteration.
            self.checkpoint_manager.save(
                iterations,
                state,
                force=iterations == final_iteration,
            )

        return state
280 281
def build_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
):
    """Build the combined critic/actor update function for one minibatch group.

    Args:
        agent: Agent providing ``critic_loss`` and ``actor_loss``.
        optimizer: Optimizer used for both the critic and the actor update.
        config: Config; ``actor_update_interval`` and ``tau`` are read.

    Returns:
        A function ``(agent_state, opt_state, sample_batches, key)`` returning
        ``((agent_state, opt_state), (critic_loss, actor_loss,
        critic_loss_dict, actor_loss_dict))``. ``sample_batches`` is indexed
        along its leading axis: all but the last entry are used for
        critic-only updates, the last entry for one critic update followed by
        one actor update and the target-network updates.
    """

    def critic_loss_fn(agent_state, sample_batch, key):
        # sample_batch: (B, ...)
        losses = agent.critic_loss(agent_state, sample_batch, key)
        return losses.critic_loss, losses

    def actor_loss_fn(agent_state, sample_batch, key):
        # sample_batch: (B, ...)
        losses = agent.actor_loss(agent_state, sample_batch, key)
        return losses.actor_loss, losses

    def _with_critic_params(agent_state, critic_params):
        new_params = agent_state.params.replace(critic_params=critic_params)
        return agent_state.replace(params=new_params)

    def _critic_params_of(agent_state):
        return agent_state.params.critic_params

    def _with_actor_params(agent_state, actor_params):
        new_params = agent_state.params.replace(actor_params=actor_params)
        return agent_state.replace(params=new_params)

    def _actor_params_of(agent_state):
        return agent_state.params.actor_params

    critic_update_fn = agent_gradient_update(
        critic_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_with_critic_params,
        detach_fn=_critic_params_of,
    )

    actor_update_fn = agent_gradient_update(
        actor_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_with_actor_params,
        detach_fn=_actor_params_of,
    )

    def _update_fn(agent_state, opt_state, sample_batches, key):
        critic_opt_state = opt_state.critic
        actor_opt_state = opt_state.actor

        key, critic_key, actor_key = jax.random.split(key, num=3)

        # Leading entries feed critic-only updates; the last one is shared by
        # the final critic update and the actor update.
        leading_batches = jtu.tree_map(lambda leaf: leaf[:-1], sample_batches)
        final_batch = jtu.tree_map(lambda leaf: leaf[-1], sample_batches)

        num_critic_only_updates = config.actor_update_interval - 1
        if num_critic_only_updates > 0:

            def _critic_only_step(carry, batch):
                rng, inner_agent_state, inner_opt_state = carry
                rng, step_key = jax.random.split(rng)

                # Loss outputs of these intermediate updates are discarded.
                _, inner_agent_state, inner_opt_state = critic_update_fn(
                    inner_opt_state, inner_agent_state, batch, step_key
                )

                return (rng, inner_agent_state, inner_opt_state), None

            key, scan_key = jax.random.split(key)

            (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                _critic_only_step,
                (scan_key, agent_state, critic_opt_state),
                leading_batches,
                length=num_critic_only_updates,
            )

        critic_out, agent_state, critic_opt_state = critic_update_fn(
            critic_opt_state, agent_state, final_batch, critic_key
        )
        critic_loss, critic_loss_dict = critic_out

        actor_out, agent_state, actor_opt_state = actor_update_fn(
            actor_opt_state, agent_state, final_batch, actor_key
        )
        actor_loss, actor_loss_dict = actor_out

        # Move both target networks toward the freshly updated online
        # parameters with coefficient config.tau.
        params = agent_state.params
        new_target_actor = soft_target_update(
            params.target_actor_params,
            params.actor_params,
            config.tau,
        )
        new_target_critic = soft_target_update(
            params.target_critic_params,
            params.critic_params,
            config.tau,
        )
        agent_state = agent_state.replace(
            params=params.replace(
                target_actor_params=new_target_actor,
                target_critic_params=new_target_critic,
            )
        )

        opt_state = opt_state.replace(actor=actor_opt_state, critic=critic_opt_state)

        losses = (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict)
        return (agent_state, opt_state), losses

    return _update_fn