Source code for evorl.algorithms.ppo

  1import logging
  2import math
  3from functools import partial
  4from typing import Any
  5from omegaconf import DictConfig
  6
  7import chex
  8import flax.linen as nn
  9import jax
 10import jax.numpy as jnp
 11import jax.tree_util as jtu
 12import optax
 13
 14from evorl.distributed import agent_gradient_update, psum
 15from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
 16from evorl.envs import AutoresetMode, create_env, Space, Box, Discrete
 17from evorl.evaluators import Evaluator
 18from evorl.metrics import TrainMetric, MetricBase
 19from evorl.networks import make_policy_network, make_v_network
 20from evorl.rollout import rollout
 21from evorl.sample_batch import SampleBatch
 22from evorl.types import (
 23    MISSING_REWARD,
 24    Action,
 25    LossDict,
 26    Params,
 27    PolicyExtraInfo,
 28    PyTreeData,
 29    PyTreeDict,
 30    State,
 31    pytree_field,
 32)
 33from evorl.utils import running_statistics
 34from evorl.utils.jax_utils import (
 35    tree_get,
 36    tree_stop_gradient,
 37    scan_and_mean,
 38)
 39from evorl.utils.rl_toolkits import (
 40    average_episode_discount_return,
 41    compute_gae_with_horizon,
 42    flatten_rollout_trajectory,
 43    approximate_kl,
 44)
 45from evorl.workflows import OnPolicyWorkflow
 46from evorl.recorders import add_prefix
 47from evorl.agent import Agent, AgentState
 48
 49logger = logging.getLogger(__name__)
 50
 51
class PPONetworkParams(PyTreeData):
    """Parameter container for the two networks trained by PPO."""

    # Parameters of the policy (actor) network.
    policy_params: Params
    # Parameters of the value (critic) network.
    value_params: Params
57 58
class PPOAgent(Agent):
    """Agent holding a policy network, a value network and the clipped PPO loss.

    Attributes:
        continuous_action: If True, the policy output is split in two halves
            along the last axis and fed to ``get_tanh_norm_dist``; otherwise
            the output is passed to ``get_categorical_dist``.
        policy_network: Module producing the raw policy output.
        value_network: Module producing state values.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``;
            ``None`` disables observation normalization.
        clip_epsilon: Clipping range of the probability ratio.
        normalize_gae: Whether advantages are standardized inside ``loss``.
        policy_obs_key: Stored on the agent; not read by any method of this class.
        value_obs_key: Stored on the agent; not read by any method of this class.
    """

    continuous_action: bool
    policy_network: nn.Module
    value_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    clip_epsilon: float = 0.2
    normalize_gae: bool = True
    policy_obs_key: str = ""
    value_obs_key: str = ""

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor has been configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, sample_batch: SampleBatch):
        """Return the batch observations, normalized when a preprocessor is set."""
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)
        return obs

    def _action_dist(self, agent_state: AgentState, obs):
        """Apply the policy network and wrap its output in an action distribution."""
        policy_out = self.policy_network.apply(agent_state.params.policy_params, obs)
        if self.continuous_action:
            # The last axis is split into the two arguments of the tanh-normal dist.
            return get_tanh_norm_dist(*jnp.split(policy_out, 2, axis=-1))
        return get_categorical_dist(policy_out)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; one sample of it is used as dummy input.
            action_space: Not used by this method.
            key: PRNG key used for the parameter init of both networks.

        Returns:
            An ``AgentState`` with freshly initialized network parameters and,
            when observation normalization is enabled, a fresh preprocessor state.
        """
        policy_key, value_key = jax.random.split(key, 2)

        # Dummy observation with an extra leading axis of size 1.
        # NOTE(review): the parent `key` is reused for sampling after the split.
        batched_obs = jtu.tree_map(lambda leaf: leaf[None, ...], obs_space.sample(key))

        network_params = PPONetworkParams(
            policy_params=self.policy_network.init(policy_key, batched_obs),
            value_params=self.value_network.init(value_key, batched_obs),
        )

        # Note: statistics are broadcasted to [T*B]
        preprocessor_state = (
            running_statistics.init_state(tree_get(batched_obs, 0))
            if self.normalize_obs
            else None
        )

        return AgentState(
            params=network_params, obs_preprocessor_state=preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions from the current policy.

        Returns:
            The sampled actions and a ``PyTreeDict`` whose ``logp`` entry holds
            the log-probability of those actions under the sampling distribution.
        """
        dist = self._action_dist(
            agent_state, self._preprocess_obs(agent_state, sample_batch)
        )
        actions = dist.sample(seed=key)
        return actions, PyTreeDict(logp=dist.log_prob(actions))

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the mode of the action distribution; ``key`` is not used.

        Returns:
            The mode actions and an empty ``PyTreeDict``.
        """
        dist = self._action_dist(
            agent_state, self._preprocess_obs(agent_state, sample_batch)
        )
        return dist.mode(), PyTreeDict()

    def loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the PPO loss terms on a batch of transitions.

        Reads ``extras.v_targets``, ``extras.advantages``,
        ``extras.policy_extras.logp`` and ``extras.env_extras.autoreset`` from
        ``sample_batch``.

        Returns:
            A ``PyTreeDict`` with ``actor_loss``, ``critic_loss``,
            ``actor_entropy`` and ``approx_kl``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch)
        extras = sample_batch.extras

        # mask invalid transitions at autoreset
        valid = jnp.logical_not(extras.env_extras.autoreset)

        # ======= critic =======
        values = self.value_network.apply(agent_state.params.value_params, obs)
        critic_loss = optax.squared_error(values, extras.v_targets).mean(where=valid)

        # ====== actor =======
        dist = self._action_dist(agent_state, obs)

        # [T*B]: log of the ratio between the current and the behavior policy
        log_ratio = dist.log_prob(sample_batch.actions) - extras.policy_extras.logp
        ratio = jnp.exp(log_ratio)

        # advantages: [T*B]
        advantages = extras.advantages
        if self.normalize_gae:
            # NOTE(review): these statistics include the transitions masked out
            # below, unlike the masked means of the loss terms - confirm intended.
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        clipped_ratio = jnp.clip(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
        surrogate = jnp.minimum(ratio * advantages, clipped_ratio * advantages)
        actor_loss = -surrogate.mean(where=valid)

        # entropy: [T*B]; only the continuous-action branch passes a seed
        if self.continuous_action:
            entropy = dist.entropy(seed=key)
        else:
            entropy = dist.entropy()
        actor_entropy = entropy.mean(where=valid)

        return PyTreeDict(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            actor_entropy=actor_entropy,
            approx_kl=approximate_kl(log_ratio),
        )

    def compute_values(
        self, agent_state: AgentState, sample_batch: SampleBatch
    ) -> chex.Array:
        """Apply the value network to the (optionally normalized) observations."""
        return self.value_network.apply(
            agent_state.params.value_params,
            self._preprocess_obs(agent_state, sample_batch),
        )
212 213
def make_mlp_ppo_agent(
    action_space: Space,
    clip_epsilon: float = 0.2,
    actor_hidden_layer_sizes: tuple[int] = (256, 256),
    critic_hidden_layer_sizes: tuple[int] = (256, 256),
    normalize_obs: bool = False,
    normalize_gae: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
):
    """Build a ``PPOAgent`` for a ``Box`` or ``Discrete`` action space.

    Args:
        action_space: ``Box`` (policy output size is twice ``shape[0]``) or
            ``Discrete`` (policy output size is ``n``).
        clip_epsilon: Forwarded to ``PPOAgent``.
        actor_hidden_layer_sizes: Hidden layer sizes of the policy network.
        critic_hidden_layer_sizes: Hidden layer sizes of the value network.
        normalize_obs: If True, ``running_statistics.normalize`` is used as the
            observation preprocessor; otherwise no preprocessor is set.
        normalize_gae: Forwarded to ``PPOAgent``.
        policy_obs_key: Passed to the policy network factory and to the agent.
        value_obs_key: Passed to the value network factory and to the agent.

    Returns:
        The configured ``PPOAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor ``Discrete``.
    """
    if isinstance(action_space, Box):
        # Two policy outputs per action dimension.
        continuous_action, action_size = True, action_space.shape[0] * 2
    elif isinstance(action_space, Discrete):
        continuous_action, action_size = False, action_space.n
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    return PPOAgent(
        continuous_action=continuous_action,
        policy_network=make_policy_network(
            action_size=action_size,
            hidden_layer_sizes=actor_hidden_layer_sizes,
            obs_key=policy_obs_key,
        ),
        value_network=make_v_network(
            hidden_layer_sizes=critic_hidden_layer_sizes,
            obs_key=value_obs_key,
        ),
        obs_preprocessor=running_statistics.normalize if normalize_obs else None,
        clip_epsilon=clip_epsilon,
        normalize_gae=normalize_gae,
        policy_obs_key=policy_obs_key,
        value_obs_key=value_obs_key,
    )
259 260
class PPOWorkflow(OnPolicyWorkflow):
    """On-policy workflow that trains a ``PPOAgent`` from rollouts."""

    @classmethod
    def name(cls):
        """Return the workflow name."""
        return "PPO"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Divide the per-run batch sizes by the number of devices, in place.

        ``num_envs``, ``num_eval_envs`` and ``minibatch_size`` are each replaced
        by their floor division by ``jax.device_count()``; a warning is logged
        for every value that is not an exact multiple.
        """
        num_devices = jax.device_count()

        for field in ("num_envs", "num_eval_envs", "minibatch_size"):
            value = config[field]
            if value % num_devices != 0:
                logger.warning(
                    f"{field}({value}) cannot be divided by num_devices({num_devices}), "
                    f"rescale {field} to {value // num_devices}"
                )
            config[field] = value // num_devices

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create the environments, agent, optimizer and evaluator from ``config``.

        Returns:
            A workflow instance built as ``cls(env, agent, optimizer, evaluator, config)``.
        """
        max_episode_steps = config.env.max_episode_steps

        env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.ENVPOOL,
        )

        agent = make_mlp_ppo_agent(
            action_space=env.action_space,
            clip_epsilon=config.clip_epsilon,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
            normalize_gae=config.normalize_gae,
            policy_obs_key=config.agent_network.policy_obs_key,
            value_obs_key=config.agent_network.value_obs_key,
        )

        # Gradient clipping is only added for a positive, non-None clip norm.
        grad_clip_norm = config.optimizer.grad_clip_norm
        if grad_clip_norm is not None and grad_clip_norm > 0:
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        eval_env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        one_step_rollout_steps = config.num_envs * config.rollout_length
        if one_step_rollout_steps % config.minibatch_size != 0:
            # `step` truncates each shuffled rollout to a whole number of
            # minibatches, so the remainder is not trained on in that epoch.
            logger.warning(
                f"minibatch_size ({config.minibatch_size}) cannot divide "
                f"num_envs*rollout_length ({one_step_rollout_steps})"
            )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, config)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration.

        Collects a rollout of ``config.rollout_length`` steps, updates the
        observation statistics when they are in use, computes value targets and
        advantages, then runs ``config.reuse_rollout_epochs`` epochs of shuffled
        minibatch gradient updates on the flattened rollout.

        Args:
            state: Current workflow state.

        Returns:
            The training metrics of this iteration and the updated state.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # trajectory: [T, #envs, ...]
        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("autoreset", "episode_return", "termination"),
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        train_episode_return = average_episode_discount_return(
            trajectory.extras.env_extras.episode_return,
            trajectory.dones,
            dp_axis_name=self.dp_axis_name,
        )

        # ======== compute GAE =======
        # Append the last next_obs so that one extra (bootstrap) value is computed.
        _obs = jtu.tree_map(
            lambda obs, next_obs: jnp.concatenate([obs, next_obs[-1:]], axis=0),
            trajectory.obs,
            trajectory.next_obs,
        )
        # concat [values, bootstrap_value]
        # NOTE(review): values use `state.agent_state`, i.e. the observation
        # statistics from before the update above, while the loss below uses
        # the updated `agent_state` - confirm this is intended.
        vs = self.agent.compute_values(state.agent_state, SampleBatch(obs=_obs))

        v_targets, advantages = compute_gae_with_horizon(
            rewards=trajectory.rewards,
            values=vs,
            dones=trajectory.dones,
            terminations=trajectory.extras.env_extras.termination,
            gae_horizon=self.config.gae_horizon,
            gae_lambda=self.config.gae_lambda,
            discount=self.config.discount,
        )

        trajectory.extras.v_targets = jax.lax.stop_gradient(v_targets)
        trajectory.extras.advantages = jax.lax.stop_gradient(advantages)
        # [T,B,...] -> [T*B,...]
        trajectory = tree_stop_gradient(flatten_rollout_trajectory(trajectory))
        # ============================

        def loss_fn(agent_state, sample_batch, key):
            # Weighted sum of the loss terms listed in config.loss_weights.
            loss_dict = self.agent.loss(agent_state, sample_batch, key)
            loss = jnp.zeros(())
            for loss_name, weight in self.config.loss_weights.items():
                loss += weight * loss_dict[loss_name]

            return loss, loss_dict

        update_fn = agent_gradient_update(
            loss_fn, self.optimizer, dp_axis_name=self.dp_axis_name, has_aux=True
        )

        minibatch_size = self.config.minibatch_size
        num_minibatches = (
            self.config.rollout_length * self.config.num_envs // minibatch_size
        )

        def _get_shuffled_minibatch(perm_key, x):
            # Shuffle along the first axis, drop the remainder that does not
            # fill a whole minibatch, then reshape to
            # [num_minibatches, minibatch_size, ...].
            x = x[jax.random.permutation(perm_key, x.shape[0])][
                : num_minibatches * minibatch_size
            ]
            return x.reshape(num_minibatches, minibatch_size, *x.shape[1:])

        def minibatch_step(carry, trajectory):
            opt_state, agent_state, key = carry
            key, learn_key = jax.random.split(key)

            (loss, loss_dict), agent_state, opt_state = update_fn(
                opt_state, agent_state, trajectory, learn_key
            )

            return (opt_state, agent_state, key), (loss, loss_dict)

        def epoch_step(carry, _):
            opt_state, agent_state, key = carry
            perm_key, learn_key = jax.random.split(key, num=2)

            # The same perm_key is applied to every leaf of the trajectory.
            (opt_state, agent_state, key), (loss, loss_dict) = scan_and_mean(
                minibatch_step,
                (opt_state, agent_state, learn_key),
                jtu.tree_map(partial(_get_shuffled_minibatch, perm_key), trajectory),
                length=num_minibatches,
            )

            return (opt_state, agent_state, key), (loss, loss_dict)

        # loss_list: [reuse_rollout_epochs, num_minibatches]
        (opt_state, agent_state, _), (loss, loss_dict) = scan_and_mean(
            epoch_step,
            (state.opt_state, agent_state, learn_key),
            None,
            length=self.config.reuse_rollout_epochs,
        )

        # ======== update metrics ========

        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        sampled_episodes = psum(
            trajectory.dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        train_metrics = TrainMetric(
            train_episode_return=train_episode_return,
            loss=loss,
            raw_loss_dict=loss_dict,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run training iterations until ``config.total_timesteps`` is covered.

        The iteration count is ``ceil(total_timesteps / (rollout_length * num_envs))``
        and the loop starts from ``state.metrics.iterations``. Each iteration
        writes the workflow and training metrics to the recorder and calls
        ``checkpoint_manager.save``; evaluation runs every
        ``config.eval_interval`` iterations and on the last iteration.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The state after the last iteration.
        """
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        num_iters = math.ceil(self.config.total_timesteps / one_step_timesteps)

        start_iteration = state.metrics.iterations

        for i in range(start_iteration, num_iters):
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            iters = i + 1

            self.recorder.write(workflow_metrics.to_local_dict(), iters)
            train_metric_data = train_metrics.to_local_dict()
            # Record None instead of the MISSING_REWARD sentinel value.
            if train_metrics.train_episode_return == MISSING_REWARD:
                train_metric_data["train_episode_return"] = None
            self.recorder.write(train_metric_data, iters)

            if iters % self.config.eval_interval == 0 or iters == num_iters:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            self.checkpoint_manager.save(
                iters,
                state,
                force=iters == num_iters,
            )

        return state