Source code for evorl.algorithms.meta.pbt_ppo.param_ppo

  1import logging
  2from functools import partial
  3from omegaconf import DictConfig
  4
  5import chex
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9import optax
 10
 11from evorl.agent import AgentState
 12from evorl.envs import Space, create_env, AutoresetMode, Box, Discrete
 13from evorl.evaluators import Evaluator
 14from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
 15from evorl.distributed import agent_gradient_update, psum
 16from evorl.metrics import TrainMetric, MetricBase
 17from evorl.networks import make_policy_network, make_v_network
 18from evorl.rollout import rollout
 19from evorl.sample_batch import SampleBatch
 20from evorl.types import PyTreeDict, State, LossDict
 21from evorl.utils import running_statistics
 22from evorl.utils.jax_utils import tree_stop_gradient, scan_and_mean
 23from evorl.utils.rl_toolkits import (
 24    average_episode_discount_return,
 25    compute_gae,
 26    flatten_rollout_trajectory,
 27    approximate_kl,
 28)
 29
 30from evorl.algorithms.ppo import PPOWorkflow, PPOAgent
 31
 32logger = logging.getLogger(__name__)
 33
 34
class ParamPPOAgent(PPOAgent):
    """PPO agent whose ratio-clipping range lives in the agent state.

    ``init`` copies ``self.clip_epsilon`` into ``agent_state.extra_state`` and
    ``loss`` reads the clipping range from there rather than from the agent
    object. The value can therefore be changed per agent state without
    rebuilding the agent. Presumably this is done by the PBT workflow this
    module belongs to; confirm against the caller.
    """

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Initialize the parent PPO state and attach ``clip_epsilon``.

        Args:
            obs_space: observation space forwarded to ``PPOAgent.init``.
            action_space: action space forwarded to ``PPOAgent.init``.
            key: PRNG key forwarded to ``PPOAgent.init``.

        Returns:
            The parent's agent state with ``extra_state`` replaced by a
            ``PyTreeDict`` holding ``clip_epsilon`` as a float32 scalar.
        """
        agent_state = super().init(obs_space, action_space, key)

        return agent_state.replace(
            extra_state=PyTreeDict(clip_epsilon=jnp.float32(self.clip_epsilon))
        )

    def loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the unweighted PPO loss terms for one batch of transitions.

        Args:
            agent_state: current agent state. ``extra_state.clip_epsilon``
                sets the ratio clipping range.
            sample_batch: transitions carrying ``extras.v_targets``,
                ``extras.advantages``, ``extras.policy_extras.logp`` (behavior
                log-probabilities) and ``extras.env_extras.autoreset``.
            key: PRNG key, only passed to the entropy computation of the
                continuous (tanh-normal) action distribution.

        Returns:
            ``PyTreeDict`` with ``actor_loss``, ``critic_loss``,
            ``actor_entropy`` and ``approx_kl``. Weighting the terms is left
            to the caller.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        # Transitions flagged by autoreset are invalid; they are excluded from
        # every masked statistic below.
        mask = jnp.logical_not(sample_batch.extras.env_extras.autoreset)

        # ======= critic =======
        vs = self.value_network.apply(agent_state.params.value_params, obs)
        v_targets = sample_batch.extras.v_targets
        critic_loss = optax.squared_error(vs, v_targets).mean(where=mask)

        # ====== actor =======
        # [T*B, A]
        raw_actions = self.policy_network.apply(agent_state.params.policy_params, obs)

        if self.continuous_action:
            actions_dist = get_tanh_norm_dist(*jnp.split(raw_actions, 2, axis=-1))
        else:
            actions_dist = get_categorical_dist(raw_actions)

        # [T*B]
        actions_logp = actions_dist.log_prob(sample_batch.actions)
        behavior_actions_logp = sample_batch.extras.policy_extras.logp

        advantages = sample_batch.extras.advantages
        if self.normalize_gae:
            # Use statistics of the valid transitions only, so the masked-out
            # autoreset steps cannot skew the normalization of the rest.
            adv_mean = jnp.mean(advantages, where=mask)
            adv_std = jnp.std(advantages, where=mask)
            advantages = (advantages - adv_mean) / (adv_std + 1e-8)

        # Importance ratio between the current and the behavior policy.
        logrho = actions_logp - behavior_actions_logp
        rho = jnp.exp(logrho)

        clip_epsilon = agent_state.extra_state.clip_epsilon
        # advantages: [T*B]
        surrogate_unclipped = rho * advantages
        surrogate_clipped = (
            jnp.clip(rho, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
        )
        actor_loss = -jnp.minimum(surrogate_unclipped, surrogate_clipped).mean(
            where=mask
        )

        # entropy: [T*B]
        if self.continuous_action:
            # The tanh-normal entropy takes a PRNG key (presumably a
            # sample-based estimate).
            actor_entropy = actions_dist.entropy(seed=key).mean(where=mask)
        else:
            actor_entropy = actions_dist.entropy().mean(where=mask)

        # NOTE(review): approx_kl is computed over all transitions, including
        # the masked-out ones.
        approx_kl = approximate_kl(logrho)

        return PyTreeDict(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            actor_entropy=actor_entropy,
            approx_kl=approx_kl,
        )
def make_mlp_ppo_agent(
    action_space: Space,
    clip_epsilon: float = 0.2,
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    normalize_obs: bool = False,
) -> ParamPPOAgent:
    """Build a ``ParamPPOAgent`` with MLP policy and value networks.

    Args:
        action_space: a ``Box`` (continuous) or ``Discrete`` action space.
        clip_epsilon: initial PPO ratio-clipping range.
        actor_hidden_layer_sizes: hidden layer widths of the policy network.
        critic_hidden_layer_sizes: hidden layer widths of the value network.
        normalize_obs: if True, ``running_statistics.normalize`` is used as
            the observation preprocessor; otherwise no preprocessor is set.

    Returns:
        The configured ``ParamPPOAgent``.

    Raises:
        NotImplementedError: for action spaces other than ``Box``/``Discrete``.
    """
    if isinstance(action_space, Box):
        # The policy outputs two values per action dimension; the loss splits
        # them in half to parameterize a tanh-normal distribution.
        action_size = action_space.shape[0] * 2
        continuous_action = True
    elif isinstance(action_space, Discrete):
        # One logit per discrete action.
        action_size = action_space.n
        continuous_action = False
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    policy_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
    )

    value_network = make_v_network(hidden_layer_sizes=critic_hidden_layer_sizes)

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return ParamPPOAgent(
        continuous_action=continuous_action,
        policy_network=policy_network,
        value_network=value_network,
        obs_preprocessor=obs_preprocessor,
        clip_epsilon=clip_epsilon,
    )
class ParamPPOWorkflow(PPOWorkflow):
    """PPO workflow whose GAE lambda, discount and loss weights live in state.

    ``setup`` stores these hyperparameters in ``state.hp_state`` and ``step``
    reads them from there on every iteration, instead of from the static
    config. Combined with ``ParamPPOAgent`` (which keeps ``clip_epsilon`` in
    the agent state), the tunable hyperparameters are all part of the state.
    """

    @classmethod
    def name(cls):
        return "ParamPPO"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create envs, agent, optimizer and evaluator from ``config``."""
        max_episode_steps = config.env.max_episode_steps

        env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.ENVPOOL,
        )

        agent = make_mlp_ppo_agent(
            action_space=env.action_space,
            clip_epsilon=config.clip_epsilon,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
        )

        grad_clip_norm = config.optimizer.grad_clip_norm
        if grad_clip_norm is not None and grad_clip_norm > 0:
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        eval_env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        # step() drops the remainder that does not fill a whole minibatch.
        one_step_rollout_steps = config.num_envs * config.rollout_length
        if one_step_rollout_steps % config.minibatch_size != 0:
            logger.warning(
                "minibatch_size (%s) cannot divide num_envs*rollout_length (%s)",
                config.minibatch_size,
                one_step_rollout_steps,
            )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, config)

    def setup(self, key: chex.PRNGKey) -> State:
        """Extend the parent state with the tunable ``hp_state``.

        ``gae_lambda`` and ``discount`` are stored as ``g = -log(1 - x)``;
        ``step`` recovers them with ``x = 1 - exp(-g)``.
        """
        state = super().setup(key)

        return state.replace(
            hp_state=PyTreeDict(
                gae_lambda_g=-jnp.log(1 - jnp.float32(self.config.gae_lambda)),
                # discount = 1 - exp(-g)
                discount_g=-jnp.log(1 - jnp.float32(self.config.discount)),
                actor_loss_weight=jnp.float32(self.config.loss_weights.actor_loss),
                critic_loss_weight=jnp.float32(self.config.loss_weights.critic_loss),
                entropy_loss_weight=jnp.float32(self.config.loss_weights.actor_entropy),
            )
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one iteration: rollout, GAE, then epochs of minibatch updates.

        Returns:
            ``(train_metrics, new_state)``.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # trajectory: [T, #envs, ...]
        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("autoreset", "episode_return", "termination"),
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        train_episode_return = average_episode_discount_return(
            trajectory.extras.env_extras.episode_return,
            trajectory.dones,
            dp_axis_name=self.dp_axis_name,
        )

        # ======== compute GAE =======
        # Append the last next_obs so the value estimates include a bootstrap.
        _obs = jtu.tree_map(
            lambda obs, next_obs: jnp.concatenate([obs, next_obs[-1:]], axis=0),
            trajectory.obs,
            trajectory.next_obs,
        )
        # concat [values, bootstrap_value]
        vs = self.agent.compute_values(state.agent_state, SampleBatch(obs=_obs))

        # Invert the reparameterization applied in setup().
        gae_lambda = 1 - jnp.exp(-state.hp_state.gae_lambda_g)
        discount = 1 - jnp.exp(-state.hp_state.discount_g)

        v_targets, advantages = compute_gae(
            rewards=trajectory.rewards,  # peb_rewards
            values=vs,
            dones=trajectory.dones,
            terminations=trajectory.extras.env_extras.termination,
            gae_lambda=gae_lambda,
            discount=discount,
        )
        trajectory.extras.v_targets = jax.lax.stop_gradient(v_targets)
        trajectory.extras.advantages = jax.lax.stop_gradient(advantages)
        # [T,B,...] -> [T*B,...]
        trajectory = tree_stop_gradient(flatten_rollout_trajectory(trajectory))
        # ============================

        def loss_fn(agent_state, sample_batch, key):
            # Weighted sum of the agent's loss terms; weights come from hp_state.
            loss_dict = self.agent.loss(agent_state, sample_batch, key)
            loss_weights = dict(
                actor_loss=state.hp_state.actor_loss_weight,
                critic_loss=state.hp_state.critic_loss_weight,
                actor_entropy=state.hp_state.entropy_loss_weight,
            )
            loss = jnp.zeros(())
            for loss_key, weight in loss_weights.items():
                loss += weight * loss_dict[loss_key]

            return loss, loss_dict

        update_fn = agent_gradient_update(
            loss_fn, self.optimizer, dp_axis_name=self.dp_axis_name, has_aux=True
        )

        num_minibatches = (
            self.config.rollout_length
            * self.config.num_envs
            // self.config.minibatch_size
        )

        def _get_shuffled_minibatch(perm_key, x):
            # Shuffle along the leading axis, drop the remainder, and reshape
            # to [num_minibatches, minibatch_size, ...].
            x = x[jax.random.permutation(perm_key, x.shape[0])][
                : num_minibatches * self.config.minibatch_size
            ]
            return x.reshape(num_minibatches, self.config.minibatch_size, *x.shape[1:])

        def minibatch_step(carry, trajectory):
            opt_state, agent_state, key = carry
            key, learn_key = jax.random.split(key)

            (loss, loss_dict), agent_state, opt_state = update_fn(
                opt_state, agent_state, trajectory, learn_key
            )

            return (opt_state, agent_state, key), (loss, loss_dict)

        def epoch_step(carry, _):
            opt_state, agent_state, key = carry
            perm_key, learn_key = jax.random.split(key, num=2)

            (opt_state, agent_state, key), (loss, loss_dict) = scan_and_mean(
                minibatch_step,
                (opt_state, agent_state, learn_key),
                jtu.tree_map(partial(_get_shuffled_minibatch, perm_key), trajectory),
                length=num_minibatches,
            )

            return (opt_state, agent_state, key), (loss, loss_dict)

        # loss_list: [reuse_rollout_epochs, num_minibatches]
        (opt_state, agent_state, _), (loss, loss_dict) = scan_and_mean(
            epoch_step,
            (state.opt_state, agent_state, learn_key),
            None,
            length=self.config.reuse_rollout_epochs,
        )

        # ======== update metrics ========

        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory.dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        train_metrics = TrainMetric(
            train_episode_return=train_episode_return,
            loss=loss,
            raw_loss_dict=loss_dict,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )