1import logging
2from functools import partial
3import math
4from typing_extensions import Self # pytype: disable=not-supported-yet]
5from omegaconf import DictConfig
6
7import chex
8import jax
9import jax.numpy as jnp
10import jax.tree_util as jtu
11
12from evorl.distributed import (
13 agent_gradient_update,
14 psum,
15)
16from evorl.agent import AgentState, RandomAgent
17from evorl.types import PyTreeDict, State
18from evorl.metrics import MetricBase, EvaluateMetric
19from evorl.rollout import rollout
20from evorl.sample_batch import SampleBatch
21from evorl.utils import running_statistics
22from evorl.utils.jax_utils import tree_stop_gradient, scan_and_mean
23from evorl.utils.rl_toolkits import soft_target_update, flatten_rollout_trajectory
24from evorl.recorders import add_prefix, get_1d_array_statistics, get_1d_array
25
26from evorl.algorithms.offpolicy_utils import clean_trajectory, skip_replay_buffer_state
27from evorl.algorithms.td3 import TD3TrainMetric, TD3Workflow
28
29
30logger = logging.getLogger(__name__)
31
32
[docs]
class WorkflowMetric(MetricBase):
    """Running counters reported by the population workflow.

    Every field defaults to a zero-valued scalar ``uint32`` array.
    """

    # Environment steps collected so far, summed over all agents.
    sampled_timesteps: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
    # Environment steps collected so far by a single agent.
    sampled_timesteps_per_agent: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
    # Number of `done` flags observed in the collected trajectories.
    sampled_episodes: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
    # Number of completed calls to the workflow's `step`.
    iterations: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
38
39
[docs]
class PopTD3Workflow(TD3Workflow):
    """Independent TD3 agents with a shared replay buffer.

    A population of ``config.pop_size`` agents is stored as one batched pytree
    (leading population axis) and driven through ``jax.vmap``. Each agent
    collects its own rollouts and keeps its own optimizer state, while all
    transitions are written to, and sampled from, a single replay buffer.
    """

    @classmethod
    def name(cls) -> str:
        """Return the name of this workflow."""
        return "PopTD3"

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Build the workflow from ``config`` (single device only).

        Args:
            config: Workflow configuration, forwarded to the parent builder.
            enable_multi_devices: Must be ``False``; multi-device execution is
                not implemented for this workflow.
            enable_jit: Forwarded to the parent builder.

        Returns:
            The workflow instance created by ``TD3Workflow.build_from_config``.

        Raises:
            NotImplementedError: If ``enable_multi_devices`` is set or more
                than one local device is visible.
        """
        devices = jax.local_devices()

        if enable_multi_devices or len(devices) > 1:
            raise NotImplementedError("Multi-devices is not supported yet.")

        return super().build_from_config(config, enable_multi_devices, enable_jit)

    def _setup_workflow_metrics(self) -> MetricBase:
        """Return a fresh, zero-initialised ``WorkflowMetric``."""
        return WorkflowMetric()

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial workflow state and pre-fill the replay buffer.

        Args:
            key: PRNG key; split into agent, environment and replay-buffer
                keys plus the key kept in the returned state.

        Returns:
            The state after ``_postsetup_replaybuffer`` has run.
        """
        key, agent_key, env_key, rb_key = jax.random.split(key, 4)

        agent_state, opt_state = self._setup_agent_and_optimizer(agent_key)
        workflow_metrics = self._setup_workflow_metrics()

        # TODO: what about using shared init env_state?
        # One independently seeded environment state per agent.
        env_state = jax.vmap(self.env.reset)(
            jax.random.split(env_key, self.config.pop_size)
        )
        replay_buffer_state = self._setup_replaybuffer(rb_key)

        state = State(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
            replay_buffer_state=replay_buffer_state,
        )

        logger.info("Start replay buffer post-setup")

        state = self._postsetup_replaybuffer(state)

        logger.info("Complete replay buffer post-setup")

        return state

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialise ``pop_size`` agents and their optimizer states.

        Args:
            key: PRNG key, split into one key per agent.

        Returns:
            ``(agent_state, opt_state)``, both batched along a leading
            population axis. ``opt_state`` holds separate ``actor`` and
            ``critic`` optimizer states.
        """
        # Spaces are shared by all agents; only the init key is mapped.
        agent_state = jax.vmap(self.agent.init, in_axes=(None, None, 0))(
            self.env.obs_space,
            self.env.action_space,
            jax.random.split(key, self.config.pop_size),
        )

        def _opt_init(agent_state):
            return PyTreeDict(
                actor=self.optimizer.init(agent_state.params.actor_params),
                critic=self.optimizer.init(agent_state.params.critic_params),
            )

        opt_state = jax.vmap(_opt_init)(agent_state)

        return agent_state, opt_state

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Pre-fill the shared replay buffer before learning starts.

        Two phases are executed:

        1. A ``RandomAgent`` rollout of ``random_timesteps // num_envs`` steps.
        2. One rollout per initial agent, sized so that the buffer holds
           roughly ``learning_start_timesteps`` transitions in total.

        Observation-preprocessor statistics (when present) are updated per
        agent in both phases, and ``sampled_timesteps`` is advanced by the
        number of collected transitions.

        Args:
            state: Workflow state produced inside ``setup``.

        Returns:
            The state with updated key, metrics, agent state and replay
            buffer state.
        """
        action_space = self.env.action_space
        obs_space = self.env.obs_space
        config = self.config
        replay_buffer_state = state.replay_buffer_state
        agent_state = state.agent_state

        def _rollout(agent, agent_state, key, rollout_length):
            env_key, rollout_key = jax.random.split(key)

            # A fresh environment is used here; `state.env_state` is untouched.
            env_state = self.env.reset(env_key)

            trajectory, env_state = rollout(
                env_fn=self.env.step,
                action_fn=agent.compute_actions,
                env_state=env_state,
                agent_state=agent_state,
                key=rollout_key,
                rollout_length=rollout_length,
                env_extra_fields=("ori_obs", "termination"),
            )

            # [T, B, ...] -> [T*B, ...]
            trajectory = clean_trajectory(trajectory)
            trajectory = flatten_rollout_trajectory(trajectory)
            trajectory = tree_stop_gradient(trajectory)

            return trajectory

        def _update_obs_preprocessor(agent_state, trajectory):
            # Skip when there is no preprocessor or the rollout is empty
            # (a rollout length of zero is possible in both phases).
            if (
                agent_state.obs_preprocessor_state is not None
                and len(trajectory.obs) > 0
            ):
                agent_state = agent_state.replace(
                    obs_preprocessor_state=running_statistics.update(
                        agent_state.obs_preprocessor_state,
                        trajectory.obs,
                        dp_axis_name=self.dp_axis_name,
                    )
                )
            return agent_state

        # ==== fill random transitions ====

        key, random_rollout_key, rollout_key = jax.random.split(state.key, num=3)
        random_agent = RandomAgent()
        random_agent_state = random_agent.init(
            obs_space, action_space, jax.random.PRNGKey(0)
        )
        rollout_length = config.random_timesteps // config.num_envs

        trajectory = _rollout(
            random_agent,
            random_agent_state,
            key=random_rollout_key,
            rollout_length=rollout_length,
        )

        # The same random trajectory updates every agent's statistics.
        agent_state = jax.vmap(_update_obs_preprocessor, in_axes=(0, None))(
            agent_state, trajectory
        )

        replay_buffer_state = self.replay_buffer.add(replay_buffer_state, trajectory)

        rollout_timesteps = rollout_length * config.num_envs
        sampled_timesteps = psum(
            jnp.uint32(rollout_timesteps), axis_name=self.dp_axis_name
        )

        # ==== fill transition state from init agents (diff from TD3) ====
        # Clamp at zero: when the random phase already covers
        # `learning_start_timesteps`, the remaining budget is negative and
        # would otherwise yield a negative rollout length.
        rollout_length = max(
            0,
            math.ceil(
                (config.learning_start_timesteps - rollout_timesteps)
                / (config.num_envs * config.pop_size)
            ),
        )

        _vmap_rollout = jax.vmap(
            partial(_rollout, self.agent, rollout_length=rollout_length)
        )

        trajectory = _vmap_rollout(
            agent_state, jax.random.split(rollout_key, config.pop_size)
        )
        # Each agent's statistics are updated from its own trajectory.
        agent_state = jax.vmap(_update_obs_preprocessor)(agent_state, trajectory)

        # [#pop, T*B] -> [#pop*T*B, ...]
        trajectory = flatten_rollout_trajectory(trajectory)
        replay_buffer_state = self.replay_buffer.add(replay_buffer_state, trajectory)

        rollout_timesteps = rollout_length * config.num_envs * config.pop_size
        sampled_timesteps = sampled_timesteps + psum(
            jnp.uint32(rollout_timesteps), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration for the whole population.

        Each agent rolls out ``rollout_length`` steps in its own environments,
        the transitions of all agents are added to the shared replay buffer,
        and ``num_updates_per_iter`` update rounds are applied. In every
        round all agents train on the same sampled batch: the critic is
        updated ``actor_update_interval`` times (each time on a newly sampled
        batch), the actor once, followed by a soft target update.

        Args:
            state: Current workflow state.

        Returns:
            ``(train_metrics, new_state)``; the train metrics keep a leading
            population axis.
        """
        pop_size = self.config.pop_size
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        _rollout = partial(
            rollout,
            self.env.step,
            self.agent.compute_actions,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # the trajectory [#pop, T, B, ...]
        trajectory, env_state = jax.vmap(_rollout)(
            state.env_state, state.agent_state, jax.random.split(rollout_key, pop_size)
        )

        # Keep the raw done flags for episode counting before cleaning.
        trajectory_dones = trajectory.dones
        trajectory = clean_trajectory(trajectory)

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            # The preprocessor state carries a population axis, so it is
            # updated per agent from that agent's own observations, in the
            # same way as in `_postsetup_replaybuffer`.
            # [#pop, T, B, ...] -> [#pop, T*B, ...]
            per_agent_obs = tree_stop_gradient(
                jtu.tree_map(lambda x: jax.lax.collapse(x, 1, 3), trajectory.obs)
            )
            agent_state = agent_state.replace(
                obs_preprocessor_state=jax.vmap(
                    partial(running_statistics.update, dp_axis_name=self.dp_axis_name)
                )(agent_state.obs_preprocessor_state, per_agent_obs)
            )

        # [#pop, T, B, ...] -> [#pop*T*B, ...] for the shared replay buffer.
        trajectory = flatten_pop_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def critic_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.critic_loss(agent_state, sample_batch, key)

            loss = loss_dict.critic_loss
            return loss, loss_dict

        def actor_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.actor_loss(agent_state, sample_batch, key)

            loss = loss_dict.actor_loss
            return loss, loss_dict

        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, critic_params: agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, actor_params: agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )

        # Arguments are (opt_state, agent_state, sample_batch, key): everything
        # is per agent except the sample batch, which is shared.
        critic_update_fn = jax.vmap(critic_update_fn, in_axes=(0, 0, None, 0))
        actor_update_fn = jax.vmap(actor_update_fn, in_axes=(0, 0, None, 0))

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry

            critic_opt_state = opt_state.critic
            actor_opt_state = opt_state.actor

            key, critic_key, actor_key, rb_key = jax.random.split(key, num=4)

            if self.config.actor_update_interval - 1 > 0:
                # Extra critic-only updates preceding the joint update below.

                def _sample_and_update_critic_fn(carry, unused_t):
                    key, agent_state, critic_opt_state = carry

                    key, rb_key, critic_key = jax.random.split(key, num=3)
                    # it's safe to use read-only replay_buffer_state here.
                    sample_batch = self.replay_buffer.sample(
                        replay_buffer_state, rb_key
                    )

                    critic_key = jax.random.split(critic_key, pop_size)

                    _, agent_state, critic_opt_state = critic_update_fn(
                        critic_opt_state, agent_state, sample_batch, critic_key
                    )

                    return (key, agent_state, critic_opt_state), None

                key, critic_multiple_update_key = jax.random.split(key)

                (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                    _sample_and_update_critic_fn,
                    (critic_multiple_update_key, agent_state, critic_opt_state),
                    (),
                    length=self.config.actor_update_interval - 1,
                )

            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            critic_key = jax.random.split(critic_key, pop_size)
            actor_key = jax.random.split(actor_key, pop_size)

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(
                    critic_opt_state, agent_state, sample_batch, critic_key
                )
            )

            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(actor_opt_state, agent_state, sample_batch, actor_key)
            )

            target_actor_params = soft_target_update(
                agent_state.params.target_actor_params,
                agent_state.params.actor_params,
                self.config.tau,
            )
            target_critic_params = soft_target_update(
                agent_state.params.target_critic_params,
                agent_state.params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    target_actor_params=target_actor_params,
                    target_critic_params=target_critic_params,
                )
            )

            opt_state = opt_state.replace(
                actor=actor_opt_state, critic=critic_opt_state
            )

            return (
                (key, agent_state, opt_state),
                (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
            )

        # NOTE(review): `scan_and_mean` presumably averages the stacked
        # per-round outputs over the scan axis -- confirm against its
        # definition in evorl.utils.jax_utils.
        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )

        # [#pop, ...]
        train_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps_per_agent = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_timesteps = psum(
            jnp.uint32(
                self.config.rollout_length * self.config.num_envs * self.config.pop_size
            ),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_timesteps_per_agent=state.metrics.sampled_timesteps_per_agent
            + sampled_timesteps_per_agent,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate every agent for ``config.eval_episodes`` episodes.

        Args:
            state: Current workflow state.

        Returns:
            ``(eval_metrics, new_state)``. The metrics hold one mean episode
            return and one mean episode length per agent; only the PRNG key
            of the state changes.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        # [#pop, #episodes]
        raw_eval_metrics = jax.vmap(
            partial(self.evaluator.evaluate, num_episodes=self.config.eval_episodes),
        )(
            state.agent_state,
            jax.random.split(eval_key, self.config.pop_size),
        )

        # Average over the episode axis, keeping the population axis.
        eval_metrics = EvaluateMetric(
            episode_returns=raw_eval_metrics.episode_returns.mean(-1),
            episode_lengths=raw_eval_metrics.episode_lengths.mean(-1),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        state = state.replace(key=key)
        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Run the training loop until ``config.total_timesteps`` is reached.

        After every ``_multi_steps`` call the workflow and train metrics are
        written to the recorder; on evaluation iterations the population is
        evaluated and the result recorded under the ``eval`` prefix. A
        checkpoint save is requested after every call.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The state after the final iteration.
        """
        num_devices = jax.device_count()
        one_step_timesteps = (
            self.config.rollout_length * self.config.num_envs * self.config.pop_size
        )
        sampled_timesteps = state.metrics.sampled_timesteps.tolist()
        num_iters = math.ceil(
            (self.config.total_timesteps - sampled_timesteps)
            / (one_step_timesteps * self.config.fold_iters * num_devices)
        )
        start_iteration = state.metrics.iterations.tolist()
        # NOTE(review): `final_iteration` counts `_multi_steps` calls. If one
        # call advances `iterations` by `fold_iters` (not visible here), the
        # `iterations == final_iteration` checks below only line up when
        # `fold_iters == 1` -- confirm against `_multi_steps`.
        final_iteration = num_iters + start_iteration

        for _ in range(num_iters):
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            # current iteration
            iterations = state.metrics.iterations.tolist()
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            train_metrics_dict = jtu.tree_map(
                partial(get_1d_array_statistics, histogram=True),
                train_metrics.to_local_dict(),
            )

            self.recorder.write(train_metrics_dict, iterations)

            if (
                iterations % self.config.eval_interval == 0
                or iterations == final_iteration
            ):
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = jtu.tree_map(
                    get_1d_array,
                    eval_metrics.to_local_dict(),
                )

                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iterations)

            saved_state = state
            if not self.config.save_replay_buffer:
                # Drop the replay buffer from the checkpoint when not requested.
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(
                iterations,
                saved_state,
                force=iterations == final_iteration,
            )

        return state
494
495
[docs]
def flatten_pop_rollout_trajectory(trajectory: SampleBatch) -> SampleBatch:
    """Merge the first three axes of every leaf in ``trajectory``.

    Each leaf goes from ``[#pop, T, B, ...]`` to ``[#pop*T*B, ...]``; any
    trailing axes are left as they are.
    """

    def _merge_leading_axes(leaf):
        # Collapse axes 0, 1 and 2 into a single leading axis.
        return jax.lax.collapse(leaf, 0, 3)

    return jtu.tree_map(_merge_leading_axes, trajectory)