Source code for evorl.algorithms.impala

  1import logging
  2import math
  3from functools import partial
  4from typing import Any
  5
  6import chex
  7import flax.linen as nn
  8import jax
  9import jax.numpy as jnp
 10import jax.tree_util as jtu
 11import optax
 12from omegaconf import DictConfig
 13
 14from evorl.distributed import agent_gradient_update, psum
 15from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
 16from evorl.envs import AutoresetMode, create_env, Space, Box, Discrete
 17from evorl.evaluators import Evaluator
 18from evorl.metrics import TrainMetric, MetricBase
 19from evorl.networks import make_policy_network, make_v_network
 20from evorl.rollout import rollout
 21from evorl.sample_batch import SampleBatch
 22from evorl.types import (
 23    MISSING_REWARD,
 24    Action,
 25    LossDict,
 26    Params,
 27    PolicyExtraInfo,
 28    PyTreeData,
 29    PyTreeDict,
 30    State,
 31    pytree_field,
 32)
 33from evorl.utils import running_statistics
 34from evorl.utils.jax_utils import tree_stop_gradient, scan_and_mean, tree_get
 35from evorl.utils.rl_toolkits import average_episode_discount_return, approximate_kl
 36from evorl.workflows import OnPolicyWorkflow
 37from evorl.agent import Agent, AgentState
 38from evorl.recorders import add_prefix
 39
 40
 41logger = logging.getLogger(__name__)
 42
 43
class IMPALANetworkParams(PyTreeData):
    """Learner-side network parameters, grouped as a single pytree."""

    # Parameters produced by ``policy_network.init``.
    policy_params: Params
    # Parameters produced by ``value_network.init``.
    value_params: Params

# class IMPALATrainMetric(TrainMetric):
#     rho: chex.Array = jnp.zeros((), dtype=jnp.float32)

class IMPALAAgent(Agent):
    """Actor-critic agent whose loss uses V-trace corrected targets (IMPALA).

    The policy network output is turned into an action distribution by
    ``_get_action_dist``; the value network is evaluated on the trajectory
    observations to build V-trace targets and policy-gradient advantages.
    """

    # True: tanh-normal policy (continuous); False: categorical policy.
    continuous_action: bool
    policy_network: nn.Module
    value_network: nn.Module
    # Optional callable ``(obs, obs_preprocessor_state) -> obs``; None
    # disables observation normalization.
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    # Reward discount factor.
    discount: float = 0.99
    # Multiplier applied to the clipped trace coefficients ``c_t``.
    vtrace_lambda: float = 1.0
    # Upper clip for the importance ratio used in the TD errors.
    clip_rho_threshold: float = 1.0
    # Upper clip for the importance ratio used as trace coefficient.
    clip_c_threshold: float = 1.0
    # Upper clip for the importance ratio that scales the advantages.
    clip_pg_rho_threshold: float = 1.0
    # Advantage variant forwarded to ``compute_pg_advantage``
    # ("official" or "acme").
    adv_mode: str = pytree_field(default="official", static=True)

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; one sample from it (with a leading
                axis of size 1 added) is used to initialize both networks.
            action_space: Not used by this method.
            key: PRNG key, split for the policy and value initializers and
                also used to draw the dummy observation.

        Returns:
            An ``AgentState`` holding the network parameters and, when
            observation normalization is enabled, a fresh running-statistics
            state (``None`` otherwise).
        """
        policy_key, value_key = jax.random.split(key, 2)

        # Add a leading axis of size 1 to every leaf of the sampled obs.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))

        policy_params = self.policy_network.init(policy_key, dummy_obs)
        value_params = self.value_network.init(value_key, dummy_obs)

        params_state = IMPALANetworkParams(
            policy_params=policy_params, value_params=value_params
        )

        if self.normalize_obs:
            # The statistics are initialized from the dummy obs with its
            # leading axis removed again, i.e. they carry no batch axes.
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        return AgentState(
            params=params_state, obs_preprocessor_state=obs_preprocessor_state
        )

    def _get_action_dist(self, raw_actions):
        """Build the action distribution from the raw policy-network output."""
        if self.continuous_action:
            # The last axis is split into two halves that are passed, in
            # order, to the tanh-normal constructor.
            return get_tanh_norm_dist(*jnp.split(raw_actions, 2, axis=-1))
        return get_categorical_dist(raw_actions)

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions for rollout.

        Args:
            agent_state: Current agent state (parameters, preprocessor state).
            sample_batch: Batch whose ``obs`` field is fed to the policy.
            key: PRNG key used for sampling.

        Returns:
            The sampled actions and a ``PyTreeDict`` with ``logp``, their log
            probabilities under the sampling policy (consumed by ``loss`` as
            the behavior log-probabilities for importance sampling).
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        raw_actions = self.policy_network.apply(agent_state.params.policy_params, obs)
        actions_dist = self._get_action_dist(raw_actions)

        actions = actions_dist.sample(seed=key)

        policy_extras = PyTreeDict(
            # Log probabilities of the selected actions for importance sampling
            logp=actions_dist.log_prob(actions)
        )

        return actions, policy_extras

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the mode of the action distribution (used for evaluation).

        Args:
            agent_state: Current agent state (parameters, preprocessor state).
            sample_batch: Batch whose ``obs`` field is fed to the policy.
            key: Unused; kept so the signature matches ``compute_actions``.

        Returns:
            The distribution mode and an empty ``PyTreeDict``.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        raw_actions = self.policy_network.apply(agent_state.params.policy_params, obs)
        actions_dist = self._get_action_dist(raw_actions)

        return actions_dist.mode(), PyTreeDict()

    def loss(
        self, agent_state: AgentState, trajectory: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """IMPALA loss.

        Args:
            agent_state: Current agent state (parameters, preprocessor state).
            trajectory: [T, B, ...]
                a sequence of transitions, not shuffled timesteps
            key: PRNG key, passed to the entropy computation of the
                continuous-action distribution.

        Returns:
            A ``PyTreeDict`` with ``actor_loss``, ``critic_loss``,
            ``actor_entropy``, ``rho`` and ``approx_kl``. Every entry is a
            mean that excludes the transitions flagged as autoreset.
        """
        # mask invalid transitions at autoreset
        mask = jnp.logical_not(trajectory.extras.env_extras.autoreset)

        policy_obs = trajectory.obs
        # Append the final next_obs along the time axis so that a single
        # value-network pass provides both vs[:-1] and vs[1:] below.
        value_obs = jtu.tree_map(
            lambda obs, next_obs: jnp.concatenate([obs, next_obs[-1:]], axis=0),
            trajectory.obs,
            trajectory.next_obs,
        )
        if self.normalize_obs:
            preprocessor_state = agent_state.obs_preprocessor_state
            # The policy input must be normalized exactly as in
            # compute_actions; otherwise the importance ratios compare the
            # behavior policy with a policy evaluated on different inputs.
            policy_obs = self.obs_preprocessor(policy_obs, preprocessor_state)
            value_obs = self.obs_preprocessor(value_obs, preprocessor_state)

        vs = self.value_network.apply(agent_state.params.value_params, value_obs)

        behavior_actions_logp = trajectory.extras.policy_extras.logp
        behavior_actions = trajectory.actions

        # [T, B, A]
        raw_actions = self.policy_network.apply(
            agent_state.params.policy_params, policy_obs
        )
        actions_dist = self._get_action_dist(raw_actions)

        # [T, B]
        actions_logp = actions_dist.log_prob(behavior_actions)
        logrho = actions_logp - behavior_actions_logp
        rho = jnp.exp(logrho)

        # TODO: consider PEB: truncation in the middle of trajectory
        vtrace = compute_vtrace(
            rho_t=rho,
            v_t=vs[:-1],
            v_t_plus_1=vs[1:],
            rewards=trajectory.rewards,
            dones=trajectory.dones,
            terminations=trajectory.extras.env_extras.termination,
            discount=self.discount,
            lambda_=self.vtrace_lambda,
            clip_rho_threshold=self.clip_rho_threshold,
            clip_c_threshold=self.clip_c_threshold,
        )

        # The V-trace target is treated as a constant regression target.
        vtrace = jax.lax.stop_gradient(vtrace)

        # ======= critic =======

        critic_loss = optax.squared_error(vs[:-1], vtrace).mean(where=mask)

        # ====== actor =======

        pg_advantages = compute_pg_advantage(
            vtrace=vtrace,
            v_t=vs[:-1],
            v_t_plus_1=vs[1:],
            rewards=trajectory.rewards,
            terminations=trajectory.extras.env_extras.termination,
            discount=self.discount,
            lambda_=self.vtrace_lambda,
            mode=self.adv_mode,
        )

        clipped_pg_rho_t = jnp.minimum(self.clip_pg_rho_threshold, rho)
        pg_advantage = clipped_pg_rho_t * pg_advantages
        # Gradients flow only through actions_logp, not through the weight.
        pg_advantage = jax.lax.stop_gradient(pg_advantage)

        policy_loss = -(pg_advantage * actions_logp).mean(where=mask)

        if self.continuous_action:
            # The continuous distribution's entropy takes a PRNG seed.
            actor_entropy = actions_dist.entropy(seed=key).mean(where=mask)
        else:
            actor_entropy = actions_dist.entropy().mean(where=mask)

        # Masked like the other statistics so autoreset transitions do not
        # contaminate the reported KL.
        # NOTE(review): assumes approximate_kl is elementwise (same shape as
        # logrho) - confirm against evorl.utils.rl_toolkits.
        approx_kl = approximate_kl(logrho).mean(where=mask)

        return PyTreeDict(
            actor_loss=policy_loss,
            critic_loss=critic_loss,
            actor_entropy=actor_entropy,
            rho=rho.mean(where=mask),
            approx_kl=approx_kl,
        )
236 237
def make_mlp_impala_agent(
    action_space: Space,
    discount: float = 0.99,
    vtrace_lambda: float = 1.0,
    clip_rho_threshold: float = 1.0,
    clip_c_threshold: float = 1.0,
    clip_pg_rho_threshold: float = 1.0,
    adv_mode: str = "official",
    actor_hidden_layer_sizes: tuple[int] = (256, 256),
    critic_hidden_layer_sizes: tuple[int] = (256, 256),
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
):
    """Build an ``IMPALAAgent`` with separately constructed policy and value networks.

    Args:
        action_space: ``Box`` selects a continuous policy, ``Discrete`` a
            categorical one.
        discount: Forwarded to the agent.
        vtrace_lambda: Forwarded to the agent.
        clip_rho_threshold: Forwarded to the agent.
        clip_c_threshold: Forwarded to the agent.
        clip_pg_rho_threshold: Forwarded to the agent.
        adv_mode: Forwarded to the agent.
        actor_hidden_layer_sizes: Hidden layer sizes of the policy network.
        critic_hidden_layer_sizes: Hidden layer sizes of the value network.
        normalize_obs: When true, the agent uses
            ``running_statistics.normalize`` as observation preprocessor.
        policy_obs_key: ``obs_key`` passed to the policy network builder.
        value_obs_key: ``obs_key`` passed to the value network builder.

    Returns:
        The configured ``IMPALAAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor
            ``Discrete``.
    """
    continuous_action = isinstance(action_space, Box)
    if continuous_action:
        # Two outputs per action dimension: IMPALAAgent splits the policy
        # output in half to parameterize its tanh-normal distribution.
        action_size = action_space.shape[0] * 2
    elif isinstance(action_space, Discrete):
        action_size = action_space.n
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    actor = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        obs_key=policy_obs_key,
    )
    critic = make_v_network(
        hidden_layer_sizes=critic_hidden_layer_sizes,
        obs_key=value_obs_key,
    )

    preprocessor = running_statistics.normalize if normalize_obs else None

    return IMPALAAgent(
        continuous_action=continuous_action,
        policy_network=actor,
        value_network=critic,
        obs_preprocessor=preprocessor,
        discount=discount,
        vtrace_lambda=vtrace_lambda,
        clip_rho_threshold=clip_rho_threshold,
        clip_c_threshold=clip_c_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
        adv_mode=adv_mode,
    )
289 290
class IMPALAWorkflow(OnPolicyWorkflow):
    """Synchronous version of IMPALA (A2C|PPO w/ V-Trace)."""

    @classmethod
    def name(cls):
        """Return the workflow name."""
        return "IMPALA"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Divide the batch-size settings by the device count, in place.

        ``num_envs``, ``num_eval_envs`` and ``minibatch_size`` are replaced
        by their floor division by ``jax.device_count()``; a warning is
        logged for each value that is not an exact multiple.
        """
        num_devices = jax.device_count()

        if config.num_envs % num_devices:
            logger.warning(
                f"num_envs({config.num_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_envs to {config.num_envs // num_devices}"
            )
        if config.num_eval_envs % num_devices:
            logger.warning(
                f"num_eval_envs({config.num_eval_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_eval_envs to {config.num_eval_envs // num_devices}"
            )
        if config.minibatch_size % num_devices:
            logger.warning(
                f"minibatch_size({config.minibatch_size}) cannot be divided by num_devices({num_devices}), "
                f"rescale minibatch_size to {config.minibatch_size // num_devices}"
            )

        config.num_envs //= num_devices
        config.num_eval_envs //= num_devices
        config.minibatch_size //= num_devices

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create envs, agent, optimizer and evaluator, then the workflow.

        The training env is created with ``AutoresetMode.ENVPOOL`` and the
        evaluation env with ``AutoresetMode.DISABLED``. Gradient clipping by
        global norm is chained before Adam only when
        ``config.optimizer.grad_clip_norm`` is set and positive.
        """
        max_episode_steps = config.env.max_episode_steps

        env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.ENVPOOL,
        )

        # Maybe need a discount array for different agents
        network_cfg = config.agent_network
        agent = make_mlp_impala_agent(
            action_space=env.action_space,
            discount=config.discount,
            vtrace_lambda=config.vtrace_lambda,
            clip_rho_threshold=config.clip_rho_threshold,
            clip_c_threshold=config.clip_c_threshold,
            clip_pg_rho_threshold=config.clip_pg_rho_threshold,
            adv_mode=config.adv_mode,
            actor_hidden_layer_sizes=network_cfg.actor_hidden_layer_sizes,
            critic_hidden_layer_sizes=network_cfg.critic_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
            policy_obs_key=network_cfg.policy_obs_key,
            value_obs_key=network_cfg.value_obs_key,
        )

        grad_clip_norm = config.optimizer.grad_clip_norm
        use_grad_clip = grad_clip_norm is not None and grad_clip_norm > 0
        if use_grad_clip:
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        eval_env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, config)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: rollout, then minibatch updates.

        A trajectory of ``rollout_length`` steps is collected, the
        observation statistics are updated when present, and the trajectory
        is reused for ``reuse_rollout_epochs`` epochs. Each epoch permutes
        the env axis and scans over minibatches of ``minibatch_size`` envs.

        Returns:
            The training metrics of this iteration and the updated state.
        """
        cfg = self.config
        dp_axis = self.dp_axis_name

        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=cfg.rollout_length,
            env_extra_fields=("autoreset", "episode_return", "termination"),
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            # Fold the freshly collected observations into the running stats.
            new_preprocessor_state = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                dp_axis_name=dp_axis,
            )
            agent_state = agent_state.replace(
                obs_preprocessor_state=new_preprocessor_state
            )

        train_episode_return = average_episode_discount_return(
            trajectory.extras.env_extras.episode_return,
            trajectory.dones,
            dp_axis_name=dp_axis,
        )

        trajectory = tree_stop_gradient(trajectory)

        def loss_fn(agent_state, sample_batch, key):
            # learn all data from trajectory
            loss_dict = self.agent.loss(agent_state, sample_batch, key)
            total = jnp.zeros(())
            # Weighted sum over the loss terms listed in config.loss_weights.
            for term, weight in self.config.loss_weights.items():
                total = total + weight * loss_dict[term]

            return total, loss_dict

        update_fn = agent_gradient_update(
            loss_fn, self.optimizer, dp_axis_name=dp_axis, has_aux=True
        )

        # minibatch_size: num of envs in one batch
        # unit in batch: trajectory [T, B//k, ...]
        num_minibatches = cfg.num_envs // cfg.minibatch_size

        def _get_shuffled_minibatch(perm_key, x):
            # x: [T, B, ...] -> [k, T, B//k, ...]
            num_kept_envs = num_minibatches * self.config.minibatch_size
            shuffled = jax.random.permutation(perm_key, x, axis=1)
            # Drop the remainder envs that do not fill a whole minibatch.
            kept = shuffled[:, :num_kept_envs]
            chunks = jnp.split(kept, num_minibatches, axis=1)

            return jnp.stack(chunks)

        def minibatch_step(carry, minibatch):
            opt_state, agent_state, key = carry
            key, update_key = jax.random.split(key)

            (loss, loss_dict), agent_state, opt_state = update_fn(
                opt_state, agent_state, minibatch, update_key
            )

            return (opt_state, agent_state, key), (loss, loss_dict)

        def epoch_step(carry, _):
            opt_state, agent_state, key = carry
            shuffle_key, scan_key = jax.random.split(key)
            # The same key is applied to every leaf, so all fields of the
            # trajectory receive the same permutation.
            shuffle = partial(_get_shuffled_minibatch, shuffle_key)
            minibatches = jtu.tree_map(shuffle, trajectory)

            new_carry, (loss, loss_dict) = scan_and_mean(
                minibatch_step,
                (opt_state, agent_state, scan_key),
                minibatches,
                length=num_minibatches,
            )
            opt_state, agent_state, key = new_carry

            return (opt_state, agent_state, key), (loss, loss_dict)

        (opt_state, agent_state, _), (loss, loss_dict) = scan_and_mean(
            epoch_step,
            (state.opt_state, agent_state, learn_key),
            None,
            length=cfg.reuse_rollout_epochs,
        )

        # ======== update metrics ========

        local_timesteps = jnp.array(
            cfg.rollout_length * cfg.num_envs, dtype=jnp.uint32
        )
        sampled_timesteps = psum(local_timesteps, axis_name=dp_axis)

        local_episodes = trajectory.dones.sum().astype(jnp.uint32)
        sampled_episodes = psum(local_episodes, axis_name=dp_axis)

        previous = state.metrics
        workflow_metrics = previous.replace(
            sampled_timesteps=previous.sampled_timesteps + sampled_timesteps,
            sampled_episodes=previous.sampled_episodes + sampled_episodes,
            iterations=previous.iterations + 1,
        ).all_reduce(dp_axis_name=dp_axis)

        train_metrics = TrainMetric(
            train_episode_return=train_episode_return,
            loss=loss,
            raw_loss_dict=loss_dict,
        ).all_reduce(dp_axis_name=dp_axis)

        new_state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )
        return train_metrics, new_state

    def learn(self, state: State) -> State:
        """Run the training loop until ``total_timesteps`` is covered.

        Each iteration calls ``step``, records workflow and training
        metrics, evaluates every ``eval_interval`` iterations (and on the
        last one), and hands the state to the checkpoint manager.

        Returns:
            The state after the last iteration.
        """
        timesteps_per_iter = self.config.rollout_length * self.config.num_envs
        num_iters = math.ceil(self.config.total_timesteps / timesteps_per_iter)

        # Continue from the iteration counter stored in the state.
        first_iter = state.metrics.iterations

        for iter_index in range(first_iter, num_iters):
            train_metrics, state = self.step(state)
            iters = iter_index + 1
            is_last_iter = iters == num_iters

            self.recorder.write(state.metrics.to_local_dict(), iters)

            train_metric_data = train_metrics.to_local_dict()
            if train_metrics.train_episode_return == MISSING_REWARD:
                # Record None instead of the MISSING_REWARD sentinel.
                train_metric_data["train_episode_return"] = None
            self.recorder.write(train_metric_data, iters)

            if iters % self.config.eval_interval == 0 or is_last_iter:
                eval_metrics, state = self.evaluate(state)
                eval_data = add_prefix(eval_metrics.to_local_dict(), "eval")
                self.recorder.write(eval_data, iters)

            self.checkpoint_manager.save(
                iters,
                state,
                force=is_last_iter,
            )

        return state
529 530
def compute_vtrace(
    rho_t,
    v_t,
    v_t_plus_1,
    rewards,
    dones,
    terminations,
    discount=0.99,
    lambda_=1.0,
    clip_rho_threshold=1.0,
    clip_c_threshold=1.0,
):
    """Compute V-trace value targets with a reverse scan over the leading axis.

    Args:
        rho_t: Importance ratios between target and behavior policy.
        v_t: Value estimates of the current observations.
        v_t_plus_1: Value estimates of the next observations.
        rewards: Rewards of the transitions.
        dones: Episode-end flags; they cut the backward accumulation.
        terminations: Termination flags; they remove the bootstrap value
            from the TD error.
        discount: Discount factor.
        lambda_: Multiplier applied to the clipped trace coefficients.
        clip_rho_threshold: Upper clip for the ratio weighting the TD error.
        clip_c_threshold: Upper clip for the ratio used as trace coefficient.

    Returns:
        The V-trace targets, ``v_t`` plus the accumulated corrections.
    """
    # terminations is not part of this shape/dtype check.
    chex.assert_trees_all_equal_shapes_and_dtypes(
        rho_t, v_t, v_t_plus_1, rewards, dones
    )

    # Clipped importance weights: rho for the TD error, c for the trace.
    rho_clipped = jnp.minimum(clip_rho_threshold, rho_t)
    c_clipped = jnp.minimum(clip_c_threshold, rho_t) * lambda_

    # Importance-weighted one-step TD error; terminated transitions do not
    # bootstrap from v_t_plus_1.
    td_target = rewards + discount * (1 - terminations) * v_t_plus_1
    weighted_td = rho_clipped * (td_target - v_t)

    # Backward recursion for delta = vtrace - v_t.
    def _accumulate(next_delta, inputs):
        step_td, step_discount, step_c = inputs
        current_delta = step_td + step_discount * step_c * next_delta
        return current_delta, current_delta

    # Done transitions get a zero discount, so nothing is carried across
    # an episode boundary.
    trace_discounts = discount * (1 - dones)
    initial_delta = jnp.zeros_like(v_t[-1])
    _, deltas = jax.lax.scan(
        _accumulate,
        initial_delta,
        (weighted_td, trace_discounts, c_clipped),
        reverse=True,
        unroll=16,
    )

    return deltas + v_t
575 576
def compute_pg_advantage(
    vtrace,
    v_t,
    v_t_plus_1,
    rewards,
    terminations,
    discount=0.99,
    lambda_=1.0,
    mode="official",
):
    """Compute policy-gradient advantages from V-trace targets.

    The advantage is ``rewards + discount * (1 - terminations) * next_value
    - v_t``. ``next_value`` is built from the V-trace targets shifted by one
    step, with ``v_t_plus_1[-1:]`` as the final entry.

    Args:
        vtrace: V-trace targets.
        v_t: Value estimates of the current observations.
        v_t_plus_1: Value estimates of the next observations; only the last
            entry along the leading axis is used.
        rewards: Rewards of the transitions.
        terminations: Termination flags; they zero the bootstrap term.
        discount: Discount factor.
        lambda_: Mixing weight between ``vtrace`` and ``v_t``; only used
            when ``mode`` is "acme".
        mode: "official" uses the shifted V-trace targets directly; "acme"
            mixes them with ``v_t`` using ``lambda_``.

    Returns:
        The advantages, with the same leading axis as ``rewards``.

    Raises:
        ValueError: If ``mode`` is neither "official" nor "acme".
    """
    step_discounts = discount * (1 - terminations)

    # Note: rllib also follows the "official" implementation.
    if mode == "official":
        shifted_targets = vtrace[1:]
    elif mode == "acme":
        shifted_targets = lambda_ * vtrace[1:] + (1 - lambda_) * v_t[1:]
    else:
        raise ValueError(f"mode {mode} is not supported")

    # The last step has no successor target; fall back to the plain value
    # estimate of the final next observation.
    next_values = jnp.concatenate([shifted_targets, v_t_plus_1[-1:]], axis=0)

    q_values = rewards + step_discounts * next_values
    return q_values - v_t