1import logging
2from omegaconf import DictConfig
3import math
4
5import chex
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9import optax
10
11from evorl.agent import AgentState, Agent
12from evorl.distributed import agent_gradient_update
13from evorl.types import PyTreeDict, State
14from evorl.utils.jax_utils import (
15 right_shift_with_padding,
16 tree_stop_gradient,
17 scan_and_last,
18 scan_and_mean,
19)
20from evorl.utils.rl_toolkits import (
21 flatten_rollout_trajectory,
22 flatten_pop_rollout_episode,
23 soft_target_update,
24)
25from evorl.recorders import get_1d_array_statistics
26from evorl.algorithms.td3 import TD3TrainMetric, TD3NetworkParams
27
28from ..erl_workflow import ERLWorkflowBase, ERLTrainMetric
29
30logger = logging.getLogger(__name__)
31
32
[docs]
class ERLTD3WorkflowTemplate(ERLWorkflowBase):
    """A template for ERL workflow on TD3 Agent.

    Bundles the pieces shared by TD3-based ERL workflows: an optional
    EC-only warmup phase run from ``setup``, episode rollouts for the EC
    population and for the RL agents (both of which also feed the replay
    buffer), and the repeated sample-and-update loop for the RL agents.
    """

    # Note: Turn off the warmup logging in PBT or parallel training
    LOGGING_WARMUP_FLAG = True

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the base workflow and build the RL update function."""
        super().__init__(**kwargs)
        self._rl_update_fn = build_erl_rl_update_fn(
            self.agent,
            self.optimizer,
            self.config,
            self.agent_state_vmap_axes,
        )

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial state and run the warmup phase if configured.

        Args:
            key: PRNG key handed to the base class ``setup``.

        Returns:
            The workflow state after ``config.warmup_iters`` warmup
            iterations (the untouched initial state when that is ``0``).
        """
        state = super().setup(key)

        # Warmup is executed in folds of ``eval_interval`` iterations so that
        # metrics are logged once per fold; a final shorter fold covers the
        # remainder.
        if self.config.warmup_iters > 0:
            logger.info("Start warmup")

            def _warmup_step(state, unused_t):
                # Swap the return order so the state becomes the scan carry.
                train_metrics, state = self.warmup_step(state)
                return state, train_metrics

            def _logging(train_metrics, iters):
                if not self.LOGGING_WARMUP_FLAG:
                    return

                train_metrics_dict = train_metrics.to_local_dict()
                train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                    train_metrics_dict["pop_episode_returns"], histogram=True
                )

                train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                    train_metrics_dict["pop_episode_lengths"], histogram=True
                )
                # warmup_step never fills the RL fields, so drop them before
                # writing the record.
                del train_metrics_dict["rl_episode_lengths"]
                del train_metrics_dict["rl_episode_returns"]
                del train_metrics_dict["rl_metrics"]
                # Use the step passed by the caller instead of reaching into
                # the enclosing ``state`` through the closure.
                self.recorder.write(train_metrics_dict, iters)

            num_fold_iters = math.floor(
                self.config.warmup_iters / self.config.eval_interval
            )
            last_fold_iters = self.config.warmup_iters % self.config.eval_interval

            for _ in range(num_fold_iters):
                state, train_metrics = scan_and_last(
                    _warmup_step, state, (), length=self.config.eval_interval
                )
                _logging(train_metrics, state.metrics.iterations)

            if last_fold_iters > 0:
                state, train_metrics = scan_and_last(
                    _warmup_step, state, (), length=last_fold_iters
                )
                _logging(train_metrics, state.metrics.iterations)

            logger.info("Complete warmup")

        return state

    def warmup_step(self, state: State) -> tuple[ERLTrainMetric, State]:
        """Run one EC-only iteration: ask, evaluate the population, tell.

        The RL agents are not rolled out or updated here; only the EC
        optimizer state, the replay buffer state, the PRNG key and the
        workflow counters are advanced.

        Args:
            state: Current workflow state.

        Returns:
            A ``(train_metrics, state)`` tuple holding the population
            metrics of this iteration and the updated state.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state

        key, ec_rollout_key = jax.random.split(state.key, 2)

        # 1. ask()
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)
        # 2. evaluate()
        pop_agent_state = erl_replace_td3_actor_params(agent_state, pop_actor_params)
        ec_eval_metrics, ec_trajectory, replay_buffer_state = self._ec_rollout(
            pop_agent_state, replay_buffer_state, ec_rollout_key
        )
        # Fitness of an individual is its mean return over its episodes.
        fitnesses = ec_eval_metrics.episode_returns.mean(axis=-1)
        # 3. tell()
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=ec_eval_metrics.episode_returns.mean(-1),
            ec_info=ec_metrics,
            rb_size=replay_buffer_state.buffer_size,
        )

        # calculate the number of timestep
        sampled_timesteps = ec_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
        )

        return train_metrics, state

    def _ec_rollout(self, agent_state, replay_buffer_state, key):
        """Roll out the EC population and add its trajectories to the replay buffer."""
        return rollout_episode(
            agent_state,
            replay_buffer_state,
            key,
            collector=self.ec_collector,
            replay_buffer=self.replay_buffer,
            agent_state_vmap_axes=self.agent_state_vmap_axes,
            num_agents=self.config.pop_size,
            num_episodes=self.config.episodes_for_fitness,
        )

    def _rl_rollout(self, agent_state, replay_buffer_state, key):
        """Roll out the RL agents and add their trajectories to the replay buffer."""
        return rollout_episode(
            agent_state,
            replay_buffer_state,
            key,
            collector=self.rl_collector,
            replay_buffer=self.replay_buffer,
            agent_state_vmap_axes=self.agent_state_vmap_axes,
            num_agents=self.config.num_rl_agents,
            num_episodes=self.config.rollout_episodes,
        )

    def _rl_update(self, agent_state, opt_state, replay_buffer_state, key):
        """Run ``config.num_rl_updates_per_iter`` sample-and-update steps.

        Args:
            agent_state: State of the RL agents.
            opt_state: Optimizer state passed through to the RL update function.
            replay_buffer_state: Replay buffer state to sample from; it is
                only read here.
            key: PRNG key for sampling and learning.

        Returns:
            A ``(td3_metrics, agent_state, opt_state)`` tuple.
        """

        def _sample_fn(key):
            return self.replay_buffer.sample(replay_buffer_state, key)

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry

            key, rb_key, learn_key = jax.random.split(key, 3)

            # One independently sampled batch per (update step, RL agent) pair.
            rb_keys = jax.random.split(
                rb_key, self.config.actor_update_interval * self.config.num_rl_agents
            )
            sample_batches = jax.vmap(_sample_fn)(rb_keys)

            # (actor_update_interval, num_rl_agents, B, ...)
            sample_batches = jtu.tree_map(
                lambda x: x.reshape(
                    (
                        self.config.actor_update_interval,
                        self.config.num_rl_agents,
                        *x.shape[1:],
                    )
                ),
                sample_batches,
            )

            (agent_state, opt_state), train_info = self._rl_update_fn(
                agent_state, opt_state, sample_batches, learn_key
            )

            return (key, agent_state, opt_state), train_info

        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (key, agent_state, opt_state),
            (),
            length=self.config.num_rl_updates_per_iter,
        )

        # smoothed td3 metrics
        # NOTE(review): presumably scan_and_mean averages the per-step
        # outputs over the scan; confirm against its definition.
        td3_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        )

        return td3_metrics, agent_state, opt_state
219
220
[docs]
def erl_replace_td3_actor_params(
    agent_state: AgentState, pop_actor_params: TD3NetworkParams
) -> AgentState:
    """Return ``agent_state`` with its params swapped for actor-only params.

    Both the actor and the target actor are set to ``pop_actor_params``;
    the critic and target critic entries are set to ``None``.
    """
    actor_only_params = TD3NetworkParams(
        actor_params=pop_actor_params,
        target_actor_params=pop_actor_params,
        critic_params=None,
        target_critic_params=None,
    )
    return agent_state.replace(params=actor_only_params)
232
233
# All-zero scalar TD3 metrics, used as the base value by
# create_dummy_td3_trainmetric.
DUMMY_TD3_TRAINMETRIC = TD3TrainMetric(
    critic_loss=jnp.zeros(()),
    actor_loss=jnp.zeros(()),
    raw_loss_dict=PyTreeDict(
        {
            "critic_loss": jnp.zeros(()),
            "q_value": jnp.zeros(()),
            "actor_loss": jnp.zeros(()),
        }
    ),
)
243
244
[docs]
def create_dummy_td3_trainmetric(num: int) -> TD3TrainMetric:
    """Build an all-zero ``TD3TrainMetric`` with ``num`` raw-loss entries.

    Only the leaves of ``raw_loss_dict`` get a new leading axis of size
    ``num``; ``critic_loss`` and ``actor_loss`` are left as they are in
    ``DUMMY_TD3_TRAINMETRIC``.

    Raises:
        ValueError: If ``num`` is not at least 1.
    """
    if not num >= 1:
        raise ValueError(f"num should be positive, got {num}")

    def _add_leading_axis(leaf):
        return jnp.broadcast_to(leaf, (num, *leaf.shape))

    stacked_raw_losses = jtu.tree_map(
        _add_leading_axis, DUMMY_TD3_TRAINMETRIC.raw_loss_dict
    )
    return DUMMY_TD3_TRAINMETRIC.replace(raw_loss_dict=stacked_raw_losses)
255
256
[docs]
def rollout_episode(
    agent_state: AgentState,
    replay_buffer_state,
    key,
    *,
    collector,
    replay_buffer,
    agent_state_vmap_axes,
    num_episodes,
    num_agents,
):
    """Roll out ``num_agents`` agents and store the result in the replay buffer.

    ``collector.rollout`` is vmapped over ``agent_state`` (along
    ``agent_state_vmap_axes``) and over one PRNG key per agent;
    ``num_episodes`` is shared by all agents.

    Returns:
        A ``(eval_metrics, trajectory, replay_buffer_state)`` tuple, where
        ``trajectory`` is the flattened trajectory that was added to the
        buffer (without ``next_obs`` and ``dones``).
    """
    per_agent_keys = jax.random.split(key, num_agents)
    batched_rollout = jax.vmap(
        collector.rollout,
        in_axes=(agent_state_vmap_axes, 0, None),
    )
    eval_metrics, trajectory = batched_rollout(
        agent_state, per_agent_keys, num_episodes
    )

    # [n, T, B, ...] -> [T, n*B, ...]
    trajectory = flatten_pop_rollout_episode(trajectory.replace(next_obs=None))

    # NOTE(review): presumably this keeps every step up to and including the
    # first done and masks out the ones after it; confirm against
    # right_shift_with_padding.
    shifted_dones = right_shift_with_padding(trajectory.dones, 1)
    mask = jnp.logical_not(shifted_dones)

    # dones is only needed to build the mask; it is not stored.
    trajectory = trajectory.replace(dones=None)
    flat_pair = flatten_rollout_trajectory((trajectory, mask))
    trajectory, mask = tree_stop_gradient(flat_pair)

    replay_buffer_state = replay_buffer.add(replay_buffer_state, trajectory, mask)

    return eval_metrics, trajectory, replay_buffer_state
289
290
[docs]
def build_erl_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
    agent_state_vmap_axes: AgentState,
):
    """K (actor, critic) pairs.

    Builds the TD3 update function for ``config.num_rl_agents`` RL agents.

    The returned ``_update_fn(agent_state, opt_state, sample_batches, key)``
    performs ``config.actor_update_interval`` critic updates (one per leading
    slice of ``sample_batches``), then a single actor update on the last
    slice, then soft-updates both target networks with ``config.tau``. It
    returns ``((agent_state, opt_state), (critic_loss, actor_loss,
    critic_loss_dict, actor_loss_dict))`` where the losses come from the last
    critic update and the actor update.
    """
    num_rl_agents = config.num_rl_agents

    def critic_loss_fn(agent_state, sample_batch, key):
        # loss on a single critic with multiple actors
        # sample_batch: (n, B, ...)
        per_agent_keys = jax.random.split(key, num_rl_agents)
        batched_critic_loss = jax.vmap(
            agent.critic_loss, in_axes=(agent_state_vmap_axes, 0, 0)
        )
        loss_dict = batched_critic_loss(agent_state, sample_batch, per_agent_keys)

        # Per-agent losses are summed into one scalar objective.
        return loss_dict.critic_loss.sum(), loss_dict

    def actor_loss_fn(agent_state, sample_batch, key):
        # loss on a single actor
        per_agent_keys = jax.random.split(key, num_rl_agents)
        batched_actor_loss = jax.vmap(
            agent.actor_loss, in_axes=(agent_state_vmap_axes, 0, 0)
        )
        loss_dict = batched_actor_loss(agent_state, sample_batch, per_agent_keys)

        return loss_dict.actor_loss.sum(), loss_dict

    def _attach_critic_params(agent_state, critic_params):
        new_params = agent_state.params.replace(critic_params=critic_params)
        return agent_state.replace(params=new_params)

    def _detach_critic_params(agent_state):
        return agent_state.params.critic_params

    def _attach_actor_params(agent_state, actor_params):
        new_params = agent_state.params.replace(actor_params=actor_params)
        return agent_state.replace(params=new_params)

    def _detach_actor_params(agent_state):
        return agent_state.params.actor_params

    critic_update_fn = agent_gradient_update(
        critic_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_attach_critic_params,
        detach_fn=_detach_critic_params,
    )

    actor_update_fn = agent_gradient_update(
        actor_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_attach_actor_params,
        detach_fn=_detach_actor_params,
    )

    def _update_fn(agent_state, opt_state, sample_batches, key):
        critic_opt_state = opt_state.critic
        actor_opt_state = opt_state.actor

        key, critic_key, actor_key = jax.random.split(key, num=3)

        # All slices but the last feed critic-only updates; the last slice is
        # shared by the final critic update and the actor update.
        critic_sample_batches = jtu.tree_map(
            lambda leaf: leaf[:-1], sample_batches
        )
        last_sample_batch = jtu.tree_map(lambda leaf: leaf[-1], sample_batches)

        num_critic_only_updates = config.actor_update_interval - 1

        if num_critic_only_updates > 0:

            def _update_critic_fn(carry, sample_batch):
                step_key, step_agent_state, step_critic_opt_state = carry

                step_key, step_critic_key = jax.random.split(step_key)

                # Loss outputs of the intermediate critic updates are discarded.
                _, step_agent_state, step_critic_opt_state = critic_update_fn(
                    step_critic_opt_state,
                    step_agent_state,
                    sample_batch,
                    step_critic_key,
                )

                return (step_key, step_agent_state, step_critic_opt_state), None

            key, critic_multiple_update_key = jax.random.split(key)

            scan_init = (
                critic_multiple_update_key,
                agent_state,
                critic_opt_state,
            )
            (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                _update_critic_fn,
                scan_init,
                critic_sample_batches,
                length=num_critic_only_updates,
            )

        critic_outputs, agent_state, critic_opt_state = critic_update_fn(
            critic_opt_state, agent_state, last_sample_batch, critic_key
        )
        critic_loss, critic_loss_dict = critic_outputs

        actor_outputs, agent_state, actor_opt_state = actor_update_fn(
            actor_opt_state, agent_state, last_sample_batch, actor_key
        )
        actor_loss, actor_loss_dict = actor_outputs

        # not need vmap
        params = agent_state.params
        target_actor_params = soft_target_update(
            params.target_actor_params,
            params.actor_params,
            config.tau,
        )
        target_critic_params = soft_target_update(
            params.target_critic_params,
            params.critic_params,
            config.tau,
        )
        agent_state = agent_state.replace(
            params=params.replace(
                target_actor_params=target_actor_params,
                target_critic_params=target_critic_params,
            )
        )

        opt_state = opt_state.replace(actor=actor_opt_state, critic=critic_opt_state)

        train_info = (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict)
        return (agent_state, opt_state), train_info

    return _update_fn