1import logging
2import math
3from omegaconf import DictConfig
4
5import chex
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9import optax
10
11from evorl.replay_buffers import ReplayBuffer
12from evorl.distributed import agent_gradient_update
13from evorl.metrics import MetricBase
14from evorl.types import PyTreeDict, State
15from evorl.utils.jax_utils import (
16 tree_stop_gradient,
17 rng_split_like_tree,
18 right_shift_with_padding,
19 scan_and_mean,
20)
21from evorl.utils.rl_toolkits import soft_target_update, flatten_rollout_trajectory
22from evorl.evaluators import Evaluator, EpisodeCollector
23from evorl.agent import AgentState, Agent
24from evorl.envs import create_env, AutoresetMode
25from evorl.recorders import get_1d_array_statistics, add_prefix
26from evorl.ec.optimizers import ECState, OpenES, ExponentialScheduleSpec
27from evorl.algorithms.td3 import make_mlp_td3_agent, TD3TrainMetric
28from evorl.algorithms.offpolicy_utils import clean_trajectory, skip_replay_buffer_state
29
30from ..erl_workflow import ERLTrainMetric
31from .erl_td3_workflow import erl_replace_td3_actor_params, ERLTD3WorkflowTemplate
32
33logger = logging.getLogger(__name__)
34
35
[docs]
class EvaluateMetric(MetricBase):
    """Evaluation metrics for the RL agent and the EC population center.

    ``ERLEDAWorkflow.evaluate`` fills every field with the mean over the
    evaluation episodes.
    """

    # Metrics of the RL agent.
    rl_episode_returns: chex.Array
    rl_episode_lengths: chex.Array

    # Metrics of the actor built from the EC population mean.
    pop_center_episode_returns: chex.Array
    pop_center_episode_lengths: chex.Array
41
42
[docs]
class ERLEDAWorkflow(ERLTD3WorkflowTemplate):
    """ERL w/ EDA.

    Configs:

    - EC: n actors
    - RL: 1 (actor,critic)
    - Shared replay buffer

    RL will be injected into the pop mean. Support all EDA based ES algorithms.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the parent workflow and install the RL update fn."""
        super().__init__(**kwargs)

        # override
        self._rl_update_fn = build_rl_update_fn(self.agent, self.optimizer, self.config)

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "ERL-EDA"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the workflow and all of its components from ``config``.

        Creates the rollout env, the TD3 agent, the RL optimizer (with optional
        global-norm gradient clipping), the OpenES optimizer, one episode
        collector for EC and one for RL, the replay buffer and the evaluator.

        Args:
            config: Workflow configuration.

        Returns:
            The constructed workflow instance.
        """
        # env for rl&ec rollout
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is only enabled for a positive, non-None norm.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        ec_optimizer = OpenES(
            pop_size=config.pop_size,
            lr_schedule=ExponentialScheduleSpec(**config.ec_lr),
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
            mirror_sampling=config.mirror_sampling,
        )

        # EC fitness rollouts: exploratory or deterministic actions.
        if config.fitness_with_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        # RL rollouts: exploratory or deterministic actions.
        if config.rl_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        rl_collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # this is only used for _ec_rollout()
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        workflow = cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            ec_collector=ec_collector,
            rl_collector=rl_collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

        return workflow

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Initialize the RL agent, its optimizer states and the EC optimizer state.

        Args:
            key: PRNG key, split between agent init and EC optimizer init.

        Returns:
            A tuple ``(agent_state, opt_state, ec_opt_state)`` where ``opt_state``
            holds separate ``actor`` and ``critic`` optimizer states.
        """
        agent_key, ec_key = jax.random.split(key)

        # one agent for RL
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        # The EC optimizer is initialized from the RL actor's initial parameters.
        init_actor_params = agent_state.params.actor_params

        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )

        return agent_state, opt_state, ec_opt_state

    # override
    def _rl_rollout(self, agent_state, replay_buffer_state, key):
        """Roll out the RL agent and add the collected data to the replay buffer.

        Args:
            agent_state: State of the single RL agent.
            replay_buffer_state: Current replay buffer state.
            key: PRNG key for the rollout.

        Returns:
            A tuple ``(eval_metrics, trajectory, replay_buffer_state)``; the
            trajectory is the cleaned, flattened one that was added to the buffer.
        """
        # agent_state: only contains one agent
        # trajectory [T, B, ...]
        eval_metrics, trajectory = self.rl_collector.rollout(
            agent_state,
            key,
            self.config.rollout_episodes,
        )

        # The mask is computed from the raw `dones` before the trajectory is cleaned.
        mask = jnp.logical_not(right_shift_with_padding(trajectory.dones, 1))
        trajectory = clean_trajectory(trajectory)
        trajectory, mask = tree_stop_gradient(
            flatten_rollout_trajectory((trajectory, mask))
        )
        replay_buffer_state = self.replay_buffer.add(
            replay_buffer_state, trajectory, mask
        )

        return eval_metrics, trajectory, replay_buffer_state

    # override
    def _rl_update(self, agent_state, opt_state, replay_buffer_state, key):
        """Run ``config.num_rl_updates_per_iter`` RL updates from the replay buffer.

        Each update draws ``config.actor_update_interval`` batches and passes
        them to ``self._rl_update_fn``.

        Args:
            agent_state: State of the RL agent.
            opt_state: Optimizer states of the RL agent.
            replay_buffer_state: Replay buffer state to sample from.
            key: PRNG key for sampling and learning.

        Returns:
            A tuple ``(td3_metrics, agent_state, opt_state)``.
        """

        def _sample_fn(key):
            return self.replay_buffer.sample(replay_buffer_state, key)

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry

            key, rb_key, learn_key = jax.random.split(key, 3)

            rb_keys = jax.random.split(rb_key, self.config.actor_update_interval)
            # (actor_update_interval, B, ...)
            sample_batches = jax.vmap(_sample_fn)(rb_keys)

            (agent_state, opt_state), train_info = self._rl_update_fn(
                agent_state, opt_state, sample_batches, learn_key
            )

            return (key, agent_state, opt_state), train_info

        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (key, agent_state, opt_state),
            (),
            length=self.config.num_rl_updates_per_iter,
        )

        # smoothed td3 metrics
        td3_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        )

        return td3_metrics, agent_state, opt_state

    def _rl_injection(self, ec_opt_state, agent_state):
        """Move the EC population mean towards the RL actor parameters.

        Args:
            ec_opt_state: EC optimizer state whose ``mean`` is updated.
            agent_state: RL agent state providing the actor parameters.

        Returns:
            The EC optimizer state with the updated ``mean``.
        """
        # update EC pop center with RL weights

        pop_mean = ec_opt_state.mean
        rl_actor_params = agent_state.params.actor_params

        # Tips: x = x + stepsize * (y - x)
        ec_opt_state = ec_opt_state.replace(
            mean=optax.incremental_update(
                rl_actor_params, pop_mean, self.config.rl_injection_stepsize
            )
        )

        return ec_opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration.

        The iteration consists of an EC rollout, an RL rollout, the RL update,
        the EC update, and an RL injection into the population mean every
        ``config.rl_injection_interval`` iterations.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, state)`` with the updated workflow state.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, ec_rollout_key, rl_rollout_key, learn_key = jax.random.split(
            state.key, num=4
        )

        # ======== EC rollout ========
        # the trajectory [#pop, T, B, ...]
        # metrics: [#pop, B]
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)

        if self.config.mirror_sampling:
            key, perm_key = jax.random.split(key)
            pop_actor_params = jtu.tree_map(
                lambda x, k: jax.random.permutation(k, x, axis=0),
                pop_actor_params,
                rng_split_like_tree(perm_key, pop_actor_params),
            )

        pop_agent_state = erl_replace_td3_actor_params(agent_state, pop_actor_params)
        ec_eval_metrics, ec_trajectory, replay_buffer_state = self._ec_rollout(
            pop_agent_state, replay_buffer_state, ec_rollout_key
        )

        ec_sampled_timesteps = ec_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        ec_sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        # ======== RL update ========
        rl_eval_metrics, rl_trajectory, replay_buffer_state = self._rl_rollout(
            agent_state, replay_buffer_state, rl_rollout_key
        )

        rl_sampled_timesteps = rl_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        rl_sampled_episodes = jnp.uint32(self.config.rollout_episodes)

        td3_metrics, agent_state, opt_state = self._rl_update(
            agent_state, opt_state, replay_buffer_state, learn_key
        )

        # ======== EC update ========
        fitnesses = ec_eval_metrics.episode_returns.mean(axis=-1)
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        ec_opt_state = jax.lax.cond(
            iterations % self.config.rl_injection_interval == 0,
            self._rl_injection,
            lambda ec_opt_state, agent_state: ec_opt_state,
            ec_opt_state,
            agent_state,
        )

        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=ec_eval_metrics.episode_returns.mean(-1),
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
            ec_info=ec_metrics,
            rb_size=replay_buffer_state.buffer_size,
        )

        # Sum like with like: timesteps with timesteps, episodes with episodes.
        # learn() derives its remaining iteration budget from sampled_episodes,
        # so mixing the two counters would corrupt the training schedule.
        sampled_timesteps = ec_sampled_timesteps + rl_sampled_timesteps
        sampled_episodes = ec_sampled_episodes + rl_sampled_episodes
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            rl_sampled_timesteps=state.metrics.rl_sampled_timesteps
            + rl_sampled_timesteps,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the RL agent and the actor built from the EC population mean.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(eval_metrics, state)``; only the PRNG key of ``state``
            is replaced.
        """
        key, rl_eval_key, ec_eval_key = jax.random.split(state.key, num=3)

        rl_eval_metrics = self.evaluator.evaluate(
            state.agent_state, rl_eval_key, num_episodes=self.config.eval_episodes
        )

        pop_mean_actor_params = state.ec_opt_state.mean

        pop_mean_agent_state = erl_replace_td3_actor_params(
            state.agent_state, pop_mean_actor_params
        )

        ec_eval_metrics = self.evaluator.evaluate(
            pop_mean_agent_state, ec_eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(),
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(),
            pop_center_episode_returns=ec_eval_metrics.episode_returns.mean(),
            pop_center_episode_lengths=ec_eval_metrics.episode_lengths.mean(),
        )

        state = state.replace(key=key)

        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Run the training loop with logging, evaluation and checkpointing.

        The number of iterations is derived from the episodes still missing to
        reach ``config.total_episodes``, given the episodes sampled per iteration.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The workflow state after the last iteration.
        """
        sampled_episodes_per_iter = (
            self.config.episodes_for_fitness * self.config.pop_size
            + self.config.rollout_episodes
        )
        num_iters = math.ceil(
            (self.config.total_episodes - state.metrics.sampled_episodes)
            / sampled_episodes_per_iter
        )

        final_iteration = num_iters + state.metrics.iterations
        for i in range(state.metrics.iterations, final_iteration):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            # Per-individual population metrics are summarized before logging.
            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            self.recorder.write(train_metrics_dict, iters)

            # Always evaluate on the final iteration, even off-interval.
            if iters % self.config.eval_interval == 0 or iters == final_iteration:
                eval_metrics, state = self.evaluate(state)

                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=iters == final_iteration,
            )

        return state
438
439
[docs]
def build_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
):
    """Create the update function for the single RL agent.

    The returned function takes ``(agent_state, opt_state, sample_batches, key)``.
    Along the leading axis of ``sample_batches``, every slice but the last is
    used for a critic-only update; the last slice is used for one critic
    update followed by one actor update. Both target networks are then
    soft-updated with ``config.tau``.

    Args:
        agent: Agent providing ``critic_loss`` and ``actor_loss``.
        optimizer: Gradient transformation shared by actor and critic updates.
        config: Configuration; reads ``actor_update_interval`` and ``tau``.

    Returns:
        A function returning ``((agent_state, opt_state), (critic_loss,
        actor_loss, critic_loss_dict, actor_loss_dict))``.
    """

    def _critic_objective(agent_state, sample_batch, key):
        # sample_batch: (B, ...)
        info = agent.critic_loss(agent_state, sample_batch, key)
        return info.critic_loss, info

    def _actor_objective(agent_state, sample_batch, key):
        # sample_batch: (B, ...)
        info = agent.actor_loss(agent_state, sample_batch, key)
        return info.actor_loss, info

    def _read_critic_params(agent_state):
        return agent_state.params.critic_params

    def _write_critic_params(agent_state, critic_params):
        new_params = agent_state.params.replace(critic_params=critic_params)
        return agent_state.replace(params=new_params)

    def _read_actor_params(agent_state):
        return agent_state.params.actor_params

    def _write_actor_params(agent_state, actor_params):
        new_params = agent_state.params.replace(actor_params=actor_params)
        return agent_state.replace(params=new_params)

    critic_update_fn = agent_gradient_update(
        _critic_objective,
        optimizer,
        has_aux=True,
        attach_fn=_write_critic_params,
        detach_fn=_read_critic_params,
    )

    actor_update_fn = agent_gradient_update(
        _actor_objective,
        optimizer,
        has_aux=True,
        attach_fn=_write_actor_params,
        detach_fn=_read_actor_params,
    )

    def _update_fn(agent_state, opt_state, sample_batches, key):
        critic_opt_state, actor_opt_state = opt_state.critic, opt_state.actor

        key, critic_key, actor_key = jax.random.split(key, num=3)

        # Split the stacked batches: leading slices for critic-only updates,
        # the final slice for the joint critic + actor update.
        critic_only_batches = jtu.tree_map(lambda x: x[:-1], sample_batches)
        final_batch = jtu.tree_map(lambda x: x[-1], sample_batches)

        num_critic_only_updates = config.actor_update_interval - 1
        if num_critic_only_updates > 0:

            def _critic_only_step(carry, sample_batch):
                step_key, cur_agent_state, cur_critic_opt_state = carry

                step_key, update_key = jax.random.split(step_key)

                # Losses of the critic-only updates are not reported.
                _, cur_agent_state, cur_critic_opt_state = critic_update_fn(
                    cur_critic_opt_state, cur_agent_state, sample_batch, update_key
                )

                return (step_key, cur_agent_state, cur_critic_opt_state), None

            _, scan_key = jax.random.split(key)

            (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                _critic_only_step,
                (scan_key, agent_state, critic_opt_state),
                critic_only_batches,
                length=num_critic_only_updates,
            )

        critic_out, agent_state, critic_opt_state = critic_update_fn(
            critic_opt_state, agent_state, final_batch, critic_key
        )
        critic_loss, critic_loss_dict = critic_out

        actor_out, agent_state, actor_opt_state = actor_update_fn(
            actor_opt_state, agent_state, final_batch, actor_key
        )
        actor_loss, actor_loss_dict = actor_out

        # not need vmap
        params = agent_state.params
        new_target_actor_params = soft_target_update(
            params.target_actor_params,
            params.actor_params,
            config.tau,
        )
        new_target_critic_params = soft_target_update(
            params.target_critic_params,
            params.critic_params,
            config.tau,
        )
        agent_state = agent_state.replace(
            params=params.replace(
                target_actor_params=new_target_actor_params,
                target_critic_params=new_target_critic_params,
            )
        )

        opt_state = opt_state.replace(actor=actor_opt_state, critic=critic_opt_state)

        losses = (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict)
        return (agent_state, opt_state), losses

    return _update_fn