1import logging
2import math
3from functools import partial
4from typing import Any
5
6import chex
7import flax.linen as nn
8import jax
9import jax.numpy as jnp
10import jax.tree_util as jtu
11import optax
12from omegaconf import DictConfig
13
14from evorl.distributed import agent_gradient_update, psum
15from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
16from evorl.envs import AutoresetMode, create_env, Space, Box, Discrete
17from evorl.evaluators import Evaluator
18from evorl.metrics import TrainMetric, MetricBase
19from evorl.networks import make_policy_network, make_v_network
20from evorl.rollout import rollout
21from evorl.sample_batch import SampleBatch
22from evorl.types import (
23 MISSING_REWARD,
24 Action,
25 LossDict,
26 Params,
27 PolicyExtraInfo,
28 PyTreeData,
29 PyTreeDict,
30 State,
31 pytree_field,
32)
33from evorl.utils import running_statistics
34from evorl.utils.jax_utils import tree_stop_gradient, scan_and_mean, tree_get
35from evorl.utils.rl_toolkits import average_episode_discount_return, approximate_kl
36from evorl.workflows import OnPolicyWorkflow
37from evorl.agent import Agent, AgentState
38from evorl.recorders import add_prefix
39
40
41logger = logging.getLogger(__name__)
42
43
[docs]
class IMPALANetworkParams(PyTreeData):
    """Bundle of the two parameter trees optimised by the IMPALA learner.

    ``IMPALAAgent.init`` builds one instance from the freshly initialised
    policy and value networks and stores it as ``AgentState.params``.
    """

    # Parameters passed to ``policy_network.apply``.
    policy_params: Params
    # Parameters passed to ``value_network.apply``.
    value_params: Params
49
50
51# class IMPALATrainMetric(TrainMetric):
52# rho: chex.Array = jnp.zeros((), dtype=jnp.float32)
53
54
[docs]
class IMPALAAgent(Agent):
    """Actor-critic agent optimised with the IMPALA (V-trace) objective.

    The agent holds a policy network and a separate state-value network.
    Actions follow a tanh-Normal distribution when ``continuous_action`` is
    true (the policy output is split in two halves along the last axis) and a
    categorical distribution otherwise.

    Attributes:
        continuous_action: Selects the tanh-Normal (True) or categorical
            (False) action distribution.
        policy_network: Module mapping observations to raw policy outputs.
        value_network: Module mapping observations to state values.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state)`` used
            to normalise observations; ``None`` disables normalisation.
        discount: Discount factor used by the V-trace targets.
        vtrace_lambda: Lambda mixing coefficient forwarded to V-trace.
        clip_rho_threshold: Upper clip for the TD-error importance weights.
        clip_c_threshold: Upper clip for the trace-cutting weights.
        clip_pg_rho_threshold: Upper clip for the policy-gradient weights.
        adv_mode: Advantage variant forwarded to ``compute_pg_advantage``.
    """

    continuous_action: bool
    policy_network: nn.Module
    value_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    discount: float = 0.99
    vtrace_lambda: float = 1.0
    clip_rho_threshold: float = 1.0
    clip_c_threshold: float = 1.0
    clip_pg_rho_threshold: float = 1.0
    adv_mode: str = pytree_field(default="official", static=True)

    @property
    def normalize_obs(self) -> bool:
        """Whether observations are passed through ``obs_preprocessor``."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs: Any) -> Any:
        """Apply the observation preprocessor when one is configured."""
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)
        return obs

    def _actions_dist(self, agent_state: AgentState, obs: Any):
        """Build the action distribution for already-preprocessed ``obs``."""
        raw_actions = self.policy_network.apply(agent_state.params.policy_params, obs)

        if self.continuous_action:
            # The two halves of the last axis parameterise the tanh-Normal.
            return get_tanh_norm_dist(*jnp.split(raw_actions, 2, axis=-1))
        return get_categorical_dist(raw_actions)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Initialise network parameters and, if enabled, normaliser state.

        Args:
            obs_space: Observation space; one sample is used as a dummy input.
            action_space: Action space (not read by this method).
            key: PRNG key for parameter initialisation.

        Returns:
            An ``AgentState`` holding ``IMPALANetworkParams`` and the
            observation-preprocessor state (``None`` without normalisation).
        """
        policy_key, value_key = jax.random.split(key, 2)

        # Add a leading batch axis to a sampled observation.
        # NOTE(review): `key` is reused here after the split; presumably only
        # the sample's shape matters for network init -- confirm.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))

        policy_params = self.policy_network.init(policy_key, dummy_obs)
        value_params = self.value_network.init(value_key, dummy_obs)

        params_state = IMPALANetworkParams(
            policy_params=policy_params, value_params=value_params
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        return AgentState(
            params=params_state, obs_preprocessor_state=obs_preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions for rollout collection.

        Args:
            agent_state: Current parameters and preprocessor state.
            sample_batch: Batch whose ``obs`` field is read.
            key: PRNG key used to sample from the action distribution.

        Returns:
            The sampled actions and a ``PyTreeDict`` with ``logp``, the log
            probability of those actions (used later for importance sampling).
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions_dist = self._actions_dist(agent_state, obs)

        actions = actions_dist.sample(seed=key)

        policy_extras = PyTreeDict(
            # Log probabilities of the selected actions for importance sampling
            logp=actions_dist.log_prob(actions)
        )

        return actions, policy_extras

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the mode of the action distribution for evaluation.

        Args:
            agent_state: Current parameters and preprocessor state.
            sample_batch: Batch whose ``obs`` field is read.
            key: Unused; kept for a signature parallel to ``compute_actions``.

        Returns:
            The distribution mode and an empty ``PyTreeDict``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions_dist = self._actions_dist(agent_state, obs)

        return actions_dist.mode(), PyTreeDict()

    def loss(
        self, agent_state: AgentState, trajectory: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """IMPALA loss.

        Args:
            agent_state: Current parameters and preprocessor state.
            trajectory: [T, B, ...]
                a sequence of transitions, not shuffled timesteps
            key: PRNG key; forwarded to the entropy estimate of the
                continuous-action distribution.

        Returns:
            A ``PyTreeDict`` with ``actor_loss``, ``critic_loss``,
            ``actor_entropy``, ``rho`` and ``approx_kl``.
        """
        # mask invalid transitions at autoreset
        mask = jnp.logical_not(trajectory.extras.env_extras.autoreset)

        # Observations of every step plus the final next observation, so the
        # value network yields both v_t and v_{t+1} from a single call.
        extended_obs = jtu.tree_map(
            lambda obs, next_obs: jnp.concatenate([obs, next_obs[-1:]], axis=0),
            trajectory.obs,
            trajectory.next_obs,
        )
        extended_obs = self._preprocess_obs(agent_state, extended_obs)

        vs = self.value_network.apply(agent_state.params.value_params, extended_obs)

        behavior_actions_logp = trajectory.extras.policy_extras.logp
        behavior_actions = trajectory.actions

        # The policy must see observations preprocessed exactly like in
        # `compute_actions`; feeding raw observations here would corrupt the
        # importance ratios whenever normalisation is enabled. Dropping the
        # appended last step recovers the per-transition observations.
        policy_obs = jtu.tree_map(lambda x: x[:-1], extended_obs)
        actions_dist = self._actions_dist(agent_state, policy_obs)

        # [T, B]
        actions_logp = actions_dist.log_prob(behavior_actions)
        logrho = actions_logp - behavior_actions_logp
        rho = jnp.exp(logrho)

        # TODO: consider PEB: truncation in the middle of trajectory
        # hint: use IS of td-error with
        vtrace = compute_vtrace(
            rho_t=rho,
            v_t=vs[:-1],
            v_t_plus_1=vs[1:],
            rewards=trajectory.rewards,
            dones=trajectory.dones,
            terminations=trajectory.extras.env_extras.termination,
            discount=self.discount,
            lambda_=self.vtrace_lambda,
            clip_rho_threshold=self.clip_rho_threshold,
            clip_c_threshold=self.clip_c_threshold,
        )

        # Targets are constants for both the critic and the actor.
        vtrace = jax.lax.stop_gradient(vtrace)

        # ======= critic =======

        critic_loss = optax.squared_error(vs[:-1], vtrace).mean(where=mask)

        # ====== actor =======

        pg_advantages = compute_pg_advantage(
            vtrace=vtrace,
            v_t=vs[:-1],
            v_t_plus_1=vs[1:],
            rewards=trajectory.rewards,
            terminations=trajectory.extras.env_extras.termination,
            discount=self.discount,
            lambda_=self.vtrace_lambda,
            mode=self.adv_mode,
        )

        clipped_pg_rho_t = jnp.minimum(self.clip_pg_rho_threshold, rho)
        pg_advantage = clipped_pg_rho_t * pg_advantages
        # Only `actions_logp` carries gradient in the policy loss.
        pg_advantage = jax.lax.stop_gradient(pg_advantage)

        policy_loss = -(pg_advantage * actions_logp).mean(where=mask)

        if self.continuous_action:
            actor_entropy = actions_dist.entropy(seed=key).mean(where=mask)
        else:
            actor_entropy = actions_dist.entropy().mean(where=mask)

        # NOTE(review): unlike the other statistics this mean is not masked,
        # so autoreset transitions contribute -- confirm whether intended.
        approx_kl = approximate_kl(logrho).mean()

        return PyTreeDict(
            actor_loss=policy_loss,
            critic_loss=critic_loss,
            actor_entropy=actor_entropy,
            rho=rho.mean(where=mask),
            approx_kl=approx_kl,
        )
236
237
[docs]
def make_mlp_impala_agent(
    action_space: Space,
    discount: float = 0.99,
    vtrace_lambda: float = 1.0,
    clip_rho_threshold: float = 1.0,
    clip_c_threshold: float = 1.0,
    clip_pg_rho_threshold: float = 1.0,
    adv_mode: str = "official",
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> "IMPALAAgent":
    """Build an ``IMPALAAgent`` with MLP policy and value networks.

    Args:
        action_space: ``Box`` (continuous) or ``Discrete`` action space.
        discount: Discount factor stored on the agent.
        vtrace_lambda: V-trace lambda stored on the agent.
        clip_rho_threshold: Clip for the TD-error importance weights.
        clip_c_threshold: Clip for the trace-cutting weights.
        clip_pg_rho_threshold: Clip for the policy-gradient weights.
        adv_mode: Advantage variant stored on the agent.
        actor_hidden_layer_sizes: Hidden layer widths of the policy network.
        critic_hidden_layer_sizes: Hidden layer widths of the value network.
        normalize_obs: When true, observations are normalised with
            ``running_statistics.normalize``.
        policy_obs_key: Forwarded as ``obs_key`` to ``make_policy_network``.
        value_obs_key: Forwarded as ``obs_key`` to ``make_v_network``.

    Returns:
        The configured ``IMPALAAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor
            ``Discrete``.
    """
    if isinstance(action_space, Box):
        # Twice the action dimension: the agent splits the policy output in
        # two halves to parameterise its tanh-Normal distribution.
        action_size = action_space.shape[0] * 2
        continuous_action = True
    elif isinstance(action_space, Discrete):
        # One output per discrete action.
        action_size = action_space.n
        continuous_action = False
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    policy_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        obs_key=policy_obs_key,
    )

    value_network = make_v_network(
        hidden_layer_sizes=critic_hidden_layer_sizes,
        obs_key=value_obs_key,
    )

    # `None` tells the agent to skip observation normalisation.
    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return IMPALAAgent(
        continuous_action=continuous_action,
        policy_network=policy_network,
        value_network=value_network,
        obs_preprocessor=obs_preprocessor,
        discount=discount,
        vtrace_lambda=vtrace_lambda,
        clip_rho_threshold=clip_rho_threshold,
        clip_c_threshold=clip_c_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
        adv_mode=adv_mode,
    )
289
290
[docs]
class IMPALAWorkflow(OnPolicyWorkflow):
    """Synchronous version of IMPALA (A2C|PPO w/ V-Trace)."""

    @classmethod
    def name(cls):
        """Return the workflow's display name."""
        return "IMPALA"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Divide the batch-size settings by the device count, in place.

        A warning is logged for every setting that is not an exact multiple
        of the device count; the value is floor-divided regardless.
        """
        num_devices = jax.device_count()

        # All checks run before any value is modified.
        if config.num_envs % num_devices != 0:
            logger.warning(
                f"num_envs({config.num_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_envs to {config.num_envs // num_devices}"
            )
        if config.num_eval_envs % num_devices != 0:
            logger.warning(
                f"num_eval_envs({config.num_eval_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_eval_envs to {config.num_eval_envs // num_devices}"
            )
        if config.minibatch_size % num_devices != 0:
            logger.warning(
                f"minibatch_size({config.minibatch_size}) cannot be divided by num_devices({num_devices}), "
                f"rescale minibatch_size to {config.minibatch_size // num_devices}"
            )

        config.num_envs //= num_devices
        config.num_eval_envs //= num_devices
        config.minibatch_size //= num_devices

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create the environments, agent, optimizer and evaluator from config."""
        episode_limit = config.env.max_episode_steps
        network_cfg = config.agent_network

        train_env = create_env(
            config.env,
            episode_length=episode_limit,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.ENVPOOL,
        )

        # Maybe need a discount array for different agents
        agent = make_mlp_impala_agent(
            action_space=train_env.action_space,
            discount=config.discount,
            vtrace_lambda=config.vtrace_lambda,
            clip_rho_threshold=config.clip_rho_threshold,
            clip_c_threshold=config.clip_c_threshold,
            clip_pg_rho_threshold=config.clip_pg_rho_threshold,
            adv_mode=config.adv_mode,
            actor_hidden_layer_sizes=network_cfg.actor_hidden_layer_sizes,
            critic_hidden_layer_sizes=network_cfg.critic_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
            policy_obs_key=network_cfg.policy_obs_key,
            value_obs_key=network_cfg.value_obs_key,
        )

        # Gradient clipping is prepended only for a positive clip norm.
        clip_norm = config.optimizer.grad_clip_norm
        use_grad_clip = clip_norm is not None and clip_norm > 0
        if use_grad_clip:
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        eval_env = create_env(
            config.env,
            episode_length=episode_limit,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=episode_limit,
        )

        return cls(train_env, agent, optimizer, evaluator, config)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Collect one rollout and run the configured update epochs on it.

        Args:
            state: Workflow state providing the PRNG key, environment state,
                agent state, optimizer state and running metrics.

        Returns:
            The training metrics of this iteration and the updated state.
        """
        next_key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("autoreset", "episode_return", "termination"),
        )

        # Refresh the observation statistics with the new rollout, if used.
        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            refreshed_stats = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                dp_axis_name=self.dp_axis_name,
            )
            agent_state = agent_state.replace(obs_preprocessor_state=refreshed_stats)

        train_episode_return = average_episode_discount_return(
            trajectory.extras.env_extras.episode_return,
            trajectory.dones,
            dp_axis_name=self.dp_axis_name,
        )

        trajectory = tree_stop_gradient(trajectory)

        def loss_fn(agent_state, sample_batch, key):
            # Weighted sum of the loss terms named in `config.loss_weights`.
            loss_dict = self.agent.loss(agent_state, sample_batch, key)
            total_loss = jnp.zeros(())
            for term, weight in self.config.loss_weights.items():
                total_loss = total_loss + weight * loss_dict[term]

            return total_loss, loss_dict

        update_fn = agent_gradient_update(
            loss_fn, self.optimizer, dp_axis_name=self.dp_axis_name, has_aux=True
        )

        # minibatch_size: num of envs in one batch
        # unit in batch: trajectory [T, B//k, ...]
        envs_per_minibatch = self.config.minibatch_size
        num_minibatches = self.config.num_envs // envs_per_minibatch
        envs_used = num_minibatches * envs_per_minibatch

        def _split_into_minibatches(perm_key, x):
            # x: [T, B, ...] -> [k, T, B//k, ...]
            shuffled = jax.random.permutation(perm_key, x, axis=1)
            # Drop the remainder that does not fill a whole minibatch.
            kept = shuffled[:, :envs_used]
            return jnp.stack(jnp.split(kept, num_minibatches, axis=1))

        def minibatch_step(carry, minibatch):
            opt_state, agent_state, carry_key = carry
            carry_key, update_key = jax.random.split(carry_key)

            (loss, loss_dict), agent_state, opt_state = update_fn(
                opt_state, agent_state, minibatch, update_key
            )

            return (opt_state, agent_state, carry_key), (loss, loss_dict)

        def epoch_step(carry, _):
            opt_state, agent_state, epoch_key = carry
            shuffle_key, scan_key = jax.random.split(epoch_key)
            # Every leaf is permuted with the same key.
            minibatches = jtu.tree_map(
                lambda x: _split_into_minibatches(shuffle_key, x), trajectory
            )

            new_carry, losses = scan_and_mean(
                minibatch_step,
                (opt_state, agent_state, scan_key),
                minibatches,
                length=num_minibatches,
            )

            return new_carry, losses

        (opt_state, agent_state, _), (loss, loss_dict) = scan_and_mean(
            epoch_step,
            (state.opt_state, agent_state, learn_key),
            None,
            length=self.config.reuse_rollout_epochs,
        )

        # ======== update metrics ========

        timesteps_this_step = psum(
            jnp.array(
                self.config.rollout_length * self.config.num_envs, dtype=jnp.uint32
            ),
            axis_name=self.dp_axis_name,
        )

        episodes_this_step = psum(
            trajectory.dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        previous = state.metrics
        workflow_metrics = previous.replace(
            sampled_timesteps=previous.sampled_timesteps + timesteps_this_step,
            sampled_episodes=previous.sampled_episodes + episodes_this_step,
            iterations=previous.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        train_metrics = TrainMetric(
            train_episode_return=train_episode_return,
            loss=loss,
            raw_loss_dict=loss_dict,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        new_state = state.replace(
            key=next_key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )
        return train_metrics, new_state

    def learn(self, state: State) -> State:
        """Run training iterations until ``total_timesteps`` is covered.

        Metrics are recorded after every iteration. Evaluation runs every
        ``eval_interval`` iterations and on the last one; a checkpoint save is
        requested each iteration and forced on the last.

        Args:
            state: Starting state; iteration counting resumes from
                ``state.metrics.iterations``.

        Returns:
            The state after the final iteration.
        """
        timesteps_per_iter = self.config.rollout_length * self.config.num_envs
        num_iters = math.ceil(self.config.total_timesteps / timesteps_per_iter)

        first_iter = state.metrics.iterations

        for i in range(first_iter, num_iters):
            train_metrics, state = self.step(state)

            iters = i + 1
            is_last = iters == num_iters

            self.recorder.write(state.metrics.to_local_dict(), iters)

            train_metric_data = train_metrics.to_local_dict()
            if train_metrics.train_episode_return == MISSING_REWARD:
                # Record a missing value instead of the sentinel.
                train_metric_data["train_episode_return"] = None
            self.recorder.write(train_metric_data, iters)

            if is_last or iters % self.config.eval_interval == 0:
                eval_metrics, state = self.evaluate(state)
                eval_data = add_prefix(eval_metrics.to_local_dict(), "eval")
                self.recorder.write(eval_data, iters)

            self.checkpoint_manager.save(iters, state, force=is_last)

        return state
529
530
[docs]
def compute_vtrace(
    rho_t,
    v_t,
    v_t_plus_1,
    rewards,
    dones,
    terminations,
    discount=0.99,
    lambda_=1.0,
    clip_rho_threshold=1.0,
    clip_c_threshold=1.0,
):
    """Compute V-trace value targets with a backward recursion.

    With clipped weights ``rho = min(clip_rho_threshold, rho_t)`` and
    ``c = lambda_ * min(clip_c_threshold, rho_t)``, the corrections follow

        td_t    = rho * (r_t + discount * (1 - termination_t) * v_{t+1} - v_t)
        delta_t = td_t + discount * (1 - done_t) * c * delta_{t+1}

    starting from a zero correction after the last step. The recursion runs
    over the leading axis of the inputs.

    Args:
        rho_t: Unclipped importance ratios.
        v_t: Value estimates of the current observations.
        v_t_plus_1: Value estimates of the next observations.
        rewards: Rewards of each transition.
        dones: Episode-end flags; they cut the recursion across steps.
        terminations: Termination flags; they disable bootstrapping from
            ``v_t_plus_1``.
        discount: Discount factor.
        lambda_: Multiplier applied to the clipped trace weights.
        clip_rho_threshold: Upper bound for the TD-error weights.
        clip_c_threshold: Upper bound for the trace weights.

    Returns:
        The V-trace targets ``delta + v_t``.
    """
    # `terminations` is not part of this shape/dtype check.
    chex.assert_trees_all_equal_shapes_and_dtypes(
        rho_t, v_t, v_t_plus_1, rewards, dones
    )

    # Clipped importance weights for the TD error and for the trace.
    rho_bar = jnp.minimum(clip_rho_threshold, rho_t)
    c_bar = jnp.minimum(clip_c_threshold, rho_t) * lambda_

    # Importance-weighted one-step TD error.
    one_step_target = rewards + discount * (1 - terminations) * v_t_plus_1
    weighted_td = rho_bar * (one_step_target - v_t)

    def _accumulate(next_correction, step_inputs):
        # One step of the backward recursion; emits the running correction.
        step_td, step_discount, step_c = step_inputs
        correction = step_td + step_discount * step_c * next_correction
        return correction, correction

    # `dones` zero the discount so corrections never leak across episodes.
    trace_discounts = discount * (1 - dones)
    zero_correction = jnp.zeros_like(v_t[-1])
    _, corrections = jax.lax.scan(
        _accumulate,
        zero_correction,
        (weighted_td, trace_discounts, c_bar),
        reverse=True,
        unroll=16,
    )

    return corrections + v_t
575
576
[docs]
def compute_pg_advantage(
    vtrace,
    v_t,
    v_t_plus_1,
    rewards,
    terminations,
    discount=0.99,
    lambda_=1.0,
    mode="official",
):
    """Compute policy-gradient advantages from V-trace targets.

    The advantage is ``r_t + discount * (1 - termination_t) * next_t - v_t``.
    ``next_t`` is the target of the following step, and the last step falls
    back to ``v_t_plus_1[-1]``. The target depends on ``mode``:

    * ``"official"``: ``vtrace[t + 1]``.
    * ``"acme"``: ``lambda_ * vtrace[t + 1] + (1 - lambda_) * v_t[t + 1]``.

    Args:
        vtrace: V-trace targets, time on the leading axis.
        v_t: Value estimates of the current observations.
        v_t_plus_1: Value estimates of the next observations.
        rewards: Rewards of each transition.
        terminations: Termination flags; they disable bootstrapping.
        discount: Discount factor.
        lambda_: Mixing coefficient, only used in ``"acme"`` mode.
        mode: ``"official"`` or ``"acme"``.

    Returns:
        The advantages, shaped like ``rewards``.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    discounts = discount * (1 - terminations)

    # Targets of steps 1..T-1, shifted to act as next-step values.
    if mode == "official":
        # Note: rllib also follows this implementation
        shifted_targets = vtrace[1:]
    elif mode == "acme":
        shifted_targets = lambda_ * vtrace[1:] + (1 - lambda_) * v_t[1:]
    else:
        raise ValueError(f"mode {mode} is not supported")

    # The final step has no later target; bootstrap from the value estimate.
    next_values = jnp.concatenate([shifted_targets, v_t_plus_1[-1:]], axis=0)

    return rewards + discounts * next_values - v_t