1import logging
2import math
3from functools import partial
4from typing import Any
5from omegaconf import DictConfig
6
7import chex
8import flax.linen as nn
9import jax
10import jax.numpy as jnp
11import jax.tree_util as jtu
12import optax
13
14from evorl.distributed import agent_gradient_update, psum
15from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
16from evorl.envs import AutoresetMode, create_env, Space, Box, Discrete
17from evorl.evaluators import Evaluator
18from evorl.metrics import TrainMetric, MetricBase
19from evorl.networks import make_policy_network, make_v_network
20from evorl.rollout import rollout
21from evorl.sample_batch import SampleBatch
22from evorl.types import (
23 MISSING_REWARD,
24 Action,
25 LossDict,
26 Params,
27 PolicyExtraInfo,
28 PyTreeData,
29 PyTreeDict,
30 State,
31 pytree_field,
32)
33from evorl.utils import running_statistics
34from evorl.utils.jax_utils import (
35 tree_get,
36 tree_stop_gradient,
37 scan_and_mean,
38)
39from evorl.utils.rl_toolkits import (
40 average_episode_discount_return,
41 compute_gae_with_horizon,
42 flatten_rollout_trajectory,
43 approximate_kl,
44)
45from evorl.workflows import OnPolicyWorkflow
46from evorl.recorders import add_prefix
47from evorl.agent import Agent, AgentState
48
49logger = logging.getLogger(__name__)
50
51
[docs]
class PPONetworkParams(PyTreeData):
    """Bundle of the learnable parameters of the PPO actor and critic."""

    # Parameters of the policy network.
    policy_params: Params
    # Parameters of the value network.
    value_params: Params
57
58
[docs]
class PPOAgent(Agent):
    """Actor-critic agent implementing the clipped PPO objective.

    The agent owns a policy network and a value network. When
    ``obs_preprocessor`` is set, every observation is passed through it
    (together with ``agent_state.obs_preprocessor_state``) before being fed
    to either network.
    """

    # Selects the action distribution: tanh-normal when True, categorical
    # when False.
    continuous_action: bool
    policy_network: nn.Module
    value_network: nn.Module
    # Optional callable ``(obs, preprocessor_state) -> obs``; ``None`` disables
    # observation normalisation.
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    # Half-width of the clipping interval applied to the probability ratio.
    clip_epsilon: float = 0.2
    # Whether advantages are standardised inside ``loss``.
    normalize_gae: bool = True
    policy_obs_key: str = ""
    value_obs_key: str = ""

    @property
    def normalize_obs(self):
        """Return True when an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs):
        """Apply the observation preprocessor, if any, and return the result."""
        if not self.normalize_obs:
            return obs
        return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

    def _policy_dist(self, agent_state: AgentState, obs):
        """Run the policy network on ``obs`` and wrap its output in a distribution."""
        net_out = self.policy_network.apply(agent_state.params.policy_params, obs)
        if not self.continuous_action:
            return get_categorical_dist(net_out)
        # The network output is split into two equal halves along the last
        # axis; both halves are handed to the tanh-normal constructor.
        return get_tanh_norm_dist(*jnp.split(net_out, 2, axis=-1))

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; one sample from it (with a leading
                axis of size 1 added) is used to initialise both networks.
            action_space: Unused here.
            key: PRNG key, split into one key per network.

        Returns:
            An ``AgentState`` holding the network parameters and, when
            observation normalisation is enabled, freshly initialised
            running statistics (otherwise ``None``).
        """
        policy_key, value_key = jax.random.split(key, 2)

        # Add a leading batch axis of size 1 to every leaf of the sample.
        dummy_obs = jtu.tree_map(lambda leaf: leaf[None, ...], obs_space.sample(key))

        network_params = PPONetworkParams(
            policy_params=self.policy_network.init(policy_key, dummy_obs),
            value_params=self.value_network.init(value_key, dummy_obs),
        )

        # Note: statistics are broadcasted to [T*B]
        preprocessor_state = (
            running_statistics.init_state(tree_get(dummy_obs, 0))
            if self.normalize_obs
            else None
        )

        return AgentState(
            params=network_params, obs_preprocessor_state=preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions for ``sample_batch.obs``.

        Returns:
            The sampled actions and a ``PyTreeDict`` whose ``logp`` entry is
            the log-probability of those actions under the current policy.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        dist = self._policy_dist(agent_state, obs)

        actions = dist.sample(seed=key)
        return actions, PyTreeDict(logp=dist.log_prob(actions))

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the mode of the action distribution for ``sample_batch.obs``.

        ``key`` is accepted for interface parity with ``compute_actions`` but
        is not used. The second return value is an empty ``PyTreeDict``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        dist = self._policy_dist(agent_state, obs)

        return dist.mode(), PyTreeDict()

    def loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the PPO loss terms on a batch of transitions.

        ``sample_batch.extras`` must provide ``v_targets``, ``advantages``,
        ``policy_extras.logp`` (behaviour-policy log-probabilities) and
        ``env_extras.autoreset``.

        Returns:
            A ``PyTreeDict`` with ``actor_loss``, ``critic_loss``,
            ``actor_entropy`` and ``approx_kl``.
        """
        extras = sample_batch.extras
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        # Transitions flagged as autoreset are excluded from the means below.
        valid = jnp.logical_not(extras.env_extras.autoreset)

        # ======= critic =======
        values = self.value_network.apply(agent_state.params.value_params, obs)
        critic_loss = optax.squared_error(values, extras.v_targets).mean(where=valid)

        # ====== actor =======
        dist = self._policy_dist(agent_state, obs)

        # Log of the probability ratio between current and behaviour policy.
        log_ratio = dist.log_prob(sample_batch.actions) - extras.policy_extras.logp
        ratio = jnp.exp(log_ratio)

        advantages = extras.advantages
        if self.normalize_gae:
            # Standardised over the whole batch (the autoreset mask is not
            # applied to these statistics).
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        unclipped_objective = ratio * advantages
        clipped_objective = (
            jnp.clip(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
        )
        actor_loss = -jnp.minimum(unclipped_objective, clipped_objective).mean(
            where=valid
        )

        # The continuous distribution's entropy takes a PRNG seed; the
        # categorical one does not.
        if self.continuous_action:
            entropy = dist.entropy(seed=key)
        else:
            entropy = dist.entropy()
        actor_entropy = entropy.mean(where=valid)

        return PyTreeDict(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            actor_entropy=actor_entropy,
            approx_kl=approximate_kl(log_ratio),
        )

    def compute_values(
        self, agent_state: AgentState, sample_batch: SampleBatch
    ) -> chex.Array:
        """Evaluate the value network on ``sample_batch.obs``."""
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        return self.value_network.apply(agent_state.params.value_params, obs)
212
213
[docs]
def make_mlp_ppo_agent(
    action_space: Space,
    clip_epsilon: float = 0.2,
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    normalize_obs: bool = False,
    normalize_gae: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> "PPOAgent":
    """Build a ``PPOAgent`` with MLP policy and value networks.

    Args:
        action_space: ``Box`` selects a continuous-action agent, ``Discrete``
            a discrete-action agent.
        clip_epsilon: Forwarded to ``PPOAgent.clip_epsilon``.
        actor_hidden_layer_sizes: Hidden layer widths of the policy network.
        critic_hidden_layer_sizes: Hidden layer widths of the value network.
        normalize_obs: When True, ``running_statistics.normalize`` is used as
            the agent's observation preprocessor; otherwise none is set.
        normalize_gae: Forwarded to ``PPOAgent.normalize_gae``.
        policy_obs_key: Passed to the policy network as ``obs_key`` and stored
            on the agent.
        value_obs_key: Passed to the value network as ``obs_key`` and stored
            on the agent.

    Returns:
        The configured ``PPOAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor
            ``Discrete``.
    """
    if isinstance(action_space, Box):
        # Two outputs per action dimension: PPOAgent splits the policy output
        # into two halves to build its tanh-normal distribution.
        action_size = action_space.shape[0] * 2
        continuous_action = True
    elif isinstance(action_space, Discrete):
        action_size = action_space.n
        continuous_action = False
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    policy_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        obs_key=policy_obs_key,
    )

    value_network = make_v_network(
        hidden_layer_sizes=critic_hidden_layer_sizes,
        obs_key=value_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return PPOAgent(
        continuous_action=continuous_action,
        policy_network=policy_network,
        value_network=value_network,
        obs_preprocessor=obs_preprocessor,
        clip_epsilon=clip_epsilon,
        normalize_gae=normalize_gae,
        policy_obs_key=policy_obs_key,
        value_obs_key=value_obs_key,
    )
259
260
[docs]
class PPOWorkflow(OnPolicyWorkflow):
    """On-policy PPO training workflow.

    Each ``step`` collects a fixed-length rollout, computes GAE targets and
    then performs several epochs of shuffled minibatch gradient updates on
    that rollout.
    """

    @classmethod
    def name(cls):
        """Return the workflow's name."""
        return "PPO"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Divide the batch-size style settings by the number of devices.

        ``num_envs``, ``num_eval_envs`` and ``minibatch_size`` are replaced
        in place by their floor division by ``jax.device_count()``. A warning
        is logged for every value that is not an exact multiple.
        """
        num_devices = jax.device_count()

        for name in ("num_envs", "num_eval_envs", "minibatch_size"):
            value = config[name]
            if value % num_devices != 0:
                logger.warning(
                    f"{name}({value}) cannot be divided by num_devices({num_devices}), "
                    f"rescale {name} to {value // num_devices}"
                )
            config[name] = value // num_devices

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct the workflow (envs, agent, optimizer, evaluator) from ``config``."""
        max_episode_steps = config.env.max_episode_steps

        env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.ENVPOOL,
        )

        agent = make_mlp_ppo_agent(
            action_space=env.action_space,
            clip_epsilon=config.clip_epsilon,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
            normalize_gae=config.normalize_gae,
            policy_obs_key=config.agent_network.policy_obs_key,
            value_obs_key=config.agent_network.value_obs_key,
        )

        # Gradient clipping is only added when a positive norm is configured.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        eval_env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        one_step_rollout_steps = config.num_envs * config.rollout_length
        if one_step_rollout_steps % config.minibatch_size != 0:
            # ``step`` drops the remainder samples in every epoch in this case.
            logger.warning(
                f"minibatch_size ({config.minibatch_size}) cannot divide "
                f"num_envs*rollout_length ({one_step_rollout_steps})"
            )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, config)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one PPO iteration.

        The iteration collects ``rollout_length`` steps from the training
        envs, updates the observation statistics (when enabled), computes GAE
        value targets and advantages, and then runs ``reuse_rollout_epochs``
        epochs of minibatch updates over the flattened rollout.

        Args:
            state: Current workflow state.

        Returns:
            A ``(train_metrics, new_state)`` pair.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # trajectory: [T, #envs, ...]
        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("autoreset", "episode_return", "termination"),
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        train_episode_return = average_episode_discount_return(
            trajectory.extras.env_extras.episode_return,
            trajectory.dones,
            dp_axis_name=self.dp_axis_name,
        )

        # ======== compute GAE =======
        # Append the final next_obs so the value network also yields the
        # bootstrap value for the step after the rollout.
        _obs = jtu.tree_map(
            lambda obs, next_obs: jnp.concatenate([obs, next_obs[-1:]], axis=0),
            trajectory.obs,
            trajectory.next_obs,
        )
        # concat [values, bootstrap_value]
        # NOTE(review): values are computed with ``state.agent_state`` (the
        # pre-update observation statistics), while the updates below start
        # from the refreshed ``agent_state`` — confirm this is intended.
        vs = self.agent.compute_values(state.agent_state, SampleBatch(obs=_obs))

        v_targets, advantages = compute_gae_with_horizon(
            rewards=trajectory.rewards,
            values=vs,
            dones=trajectory.dones,
            terminations=trajectory.extras.env_extras.termination,
            gae_horizon=self.config.gae_horizon,
            gae_lambda=self.config.gae_lambda,
            discount=self.config.discount,
        )

        trajectory.extras.v_targets = jax.lax.stop_gradient(v_targets)
        trajectory.extras.advantages = jax.lax.stop_gradient(advantages)
        # [T,B,...] -> [T*B,...]
        trajectory = tree_stop_gradient(flatten_rollout_trajectory(trajectory))
        # ============================

        def loss_fn(agent_state, sample_batch, key):
            # Weighted sum of the loss terms named in config.loss_weights.
            loss_dict = self.agent.loss(agent_state, sample_batch, key)
            loss = jnp.zeros(())
            for loss_key, weight in self.config.loss_weights.items():
                loss += weight * loss_dict[loss_key]

            return loss, loss_dict

        update_fn = agent_gradient_update(
            loss_fn, self.optimizer, dp_axis_name=self.dp_axis_name, has_aux=True
        )

        # Floor division: samples that do not fill a whole minibatch are
        # dropped (after shuffling) in every epoch.
        num_minibatches = (
            self.config.rollout_length
            * self.config.num_envs
            // self.config.minibatch_size
        )

        def _get_shuffled_minibatch(perm_key, x):
            # Permute along the leading axis, truncate to a whole number of
            # minibatches and reshape to [num_minibatches, minibatch_size, ...].
            # The same perm_key is used for every leaf of the trajectory.
            x = x[jax.random.permutation(perm_key, x.shape[0])][
                : num_minibatches * self.config.minibatch_size
            ]
            return x.reshape(num_minibatches, self.config.minibatch_size, *x.shape[1:])

        def minibatch_step(carry, trajectory):
            # One gradient update on a single minibatch.
            opt_state, agent_state, key = carry
            key, learn_key = jax.random.split(key)

            (loss, loss_dict), agent_state, opt_state = update_fn(
                opt_state, agent_state, trajectory, learn_key
            )

            return (opt_state, agent_state, key), (loss, loss_dict)

        def epoch_step(carry, _):
            # One pass over the freshly shuffled rollout.
            opt_state, agent_state, key = carry
            perm_key, learn_key = jax.random.split(key, num=2)

            (opt_state, agent_state, key), (loss, loss_dict) = scan_and_mean(
                minibatch_step,
                (opt_state, agent_state, learn_key),
                jtu.tree_map(partial(_get_shuffled_minibatch, perm_key), trajectory),
                length=num_minibatches,
            )

            return (opt_state, agent_state, key), (loss, loss_dict)

        # loss_list: [reuse_rollout_epochs, num_minibatches]
        (opt_state, agent_state, _), (loss, loss_dict) = scan_and_mean(
            epoch_step,
            (state.opt_state, agent_state, learn_key),
            None,
            length=self.config.reuse_rollout_epochs,
        )

        # ======== update metrics ========

        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        sampled_episodes = psum(
            trajectory.dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        train_metrics = TrainMetric(
            train_episode_return=train_episode_return,
            loss=loss,
            raw_loss_dict=loss_dict,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run the training loop until ``total_timesteps`` have been sampled.

        The number of iterations is
        ``ceil(total_timesteps / (rollout_length * num_envs))``; the loop
        starts at ``state.metrics.iterations``. Metrics are written to the
        recorder after every iteration, evaluation runs every
        ``eval_interval`` iterations and on the last one, and the checkpoint
        manager is invoked after every iteration (forced on the last).

        Returns:
            The workflow state after the final iteration.
        """
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        num_iters = math.ceil(self.config.total_timesteps / one_step_timesteps)

        start_iteration = state.metrics.iterations

        for i in range(start_iteration, num_iters):
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            iters = i + 1

            self.recorder.write(workflow_metrics.to_local_dict(), iters)
            train_metric_data = train_metrics.to_local_dict()
            # MISSING_REWARD is a sentinel return value; record it as None.
            if train_metrics.train_episode_return == MISSING_REWARD:
                train_metric_data["train_episode_return"] = None
            self.recorder.write(train_metric_data, iters)

            if iters % self.config.eval_interval == 0 or iters == num_iters:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            self.checkpoint_manager.save(
                iters,
                state,
                force=iters == num_iters,
            )

        return state