1from functools import partial
2import math
3from typing_extensions import Self # pytype: disable=not-supported-yet]
4from omegaconf import DictConfig
5
6import chex
7import jax
8import jax.numpy as jnp
9import jax.tree_util as jtu
10import optax
11
12from evorl.distributed import agent_gradient_update
13from evorl.agent import Agent, AgentState
14from evorl.types import PyTreeDict, State
15from evorl.metrics import MetricBase, EvaluateMetric
16from evorl.replay_buffers import ReplayBuffer
17from evorl.envs import create_env, AutoresetMode
18from evorl.utils.rl_toolkits import (
19 soft_target_update,
20)
21from evorl.recorders import add_prefix, get_1d_array_statistics, get_1d_array
22from evorl.evaluators import Evaluator, EpisodeCollector
23
24
25from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
26from evorl.algorithms.td3 import make_mlp_td3_agent
27from evorl.algorithms.erl.cemrl_td3.cemrl_td3_workflow import CEMRLTD3WorkflowTemplate
28from evorl.algorithms.erl.cemrl_workflow import CEMRLTrainMetric
29
30
[docs]
class PopEpisodicTD3Workflow(CEMRLTD3WorkflowTemplate):
    """A batched TD3 workflow like CEMRL.

    The differences from CEMRL are:
    - Each individual has an actor and a critic.
    - All individuals are updated by RL.

    There is no evolutionary update: ``ec_optimizer`` and ``ec_opt_state`` are
    ``None``, and each iteration only runs the RL update followed by a rollout
    of every individual.
    """

    def __init__(self, **kwargs):
        # super(CEMRLTD3WorkflowTemplate, self) resolves to the class *after*
        # the template in the MRO, so the template's own __init__ is bypassed;
        # this workflow then installs its own per-individual RL update fn.
        super(CEMRLTD3WorkflowTemplate, self).__init__(**kwargs)
        self._rl_update_fn = build_rl_update_fn(
            self.agent,
            self.optimizer,
            self.config,
            self.agent_state_vmap_axes,
        )

    @classmethod
    def name(cls):
        """Return the name of this workflow."""
        return "PopEpisodicTD3"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Build envs, agent, optimizer, collector, evaluator and buffer from ``config``.

        Raises:
            AssertionError: if ``config.random_timesteps`` is not positive, or if
                ``config.pop_size != config.num_learning_offspring``.
        """
        assert config.random_timesteps > 0, (
            "random_timesteps should be positive to pre-fill some data in the replay buffer"
        )

        # build_rl_update_fn vmaps the losses over `num_learning_offspring`
        # individuals, so that count has to cover the whole population.
        assert config.pop_size == config.num_learning_offspring, (
            "pop_size must equal to num_learning_offspring"
        )

        # env for one actor
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # NOTE(review): this single optimizer is initialized on the
        # population-stacked params (see _setup_agent_and_optimizer), so
        # clip_by_global_norm presumably computes ONE norm across all
        # individuals rather than per individual -- confirm this is intended.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        # Fitness rollouts use either exploratory or deterministic actions.
        if config.fitness_with_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Params carry a leading population axis; the obs preprocessor state
        # is shared by every individual (not vmapped).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,  # shared
        )

        workflow = cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=None,
            collector=collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

        return workflow

    def _setup_agent_and_optimizer(self, key: chex.PRNGKey):
        """Initialize ``pop_size`` agents (stacked on axis 0) and their optimizer states.

        Returns:
            ``(agent_state, opt_state, ec_opt_state)`` where ``opt_state`` has
            ``actor`` and ``critic`` entries and ``ec_opt_state`` is ``None``.
        """
        agent_state = jax.vmap(self.agent.init, in_axes=(None, None, 0))(
            self.env.obs_space,
            self.env.action_space,
            jax.random.split(key, self.config.pop_size),
        )

        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )

        ec_opt_state = None

        return agent_state, opt_state, ec_opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one iteration: an RL update of every individual, then a rollout.

        The rollout's per-individual episode statistics are reported as the
        population's returns/lengths, and the workflow counters
        (``sampled_timesteps``, ``sampled_episodes``, ``iterations``) are
        advanced.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # ======== RL update ========
        td3_metrics, agent_state, opt_state = self._rl_update(
            agent_state,
            opt_state,
            replay_buffer_state,
            learn_key,
        )

        # the trajectory [T, #pop*B, ...]
        # metrics: [#pop, B]
        eval_metrics, trajectory, replay_buffer_state = self._rollout(
            agent_state, replay_buffer_state, rollout_key
        )

        train_metrics = CEMRLTrainMetric(
            rb_size=replay_buffer_state.buffer_size,
            pop_episode_lengths=eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
        )

        # calculate the number of timestep
        sampled_timesteps = eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate every individual for ``config.eval_episodes`` episodes.

        Returns:
            An ``EvaluateMetric`` with per-individual mean episode return and
            length (shape ``[#pop]``), and the state with an advanced key.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        # [#pop, #episodes]
        raw_eval_metrics = jax.vmap(
            partial(self.evaluator.evaluate, num_episodes=self.config.eval_episodes),
            in_axes=(self.agent_state_vmap_axes, 0),
        )(
            state.agent_state,
            jax.random.split(eval_key, self.config.pop_size),
        )

        eval_metrics = EvaluateMetric(
            episode_returns=raw_eval_metrics.episode_returns.mean(-1),
            episode_lengths=raw_eval_metrics.episode_lengths.mean(-1),
        )

        state = state.replace(key=key)
        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Train until ``config.total_episodes`` episodes have been sampled.

        Each iteration samples ``episodes_for_fitness * pop_size`` episodes.
        Metrics are written every iteration. Evaluation runs every
        ``eval_interval`` iterations and on the final one. A checkpoint save
        is requested every iteration and forced on the final one.

        Returns:
            The final workflow state (unchanged if the budget is already used).
        """
        episodes_per_iter = self.config.episodes_for_fitness * self.config.pop_size

        # Convert the on-device counters to Python ints before doing budget
        # arithmetic: `sampled_episodes` is a uint32 array, so subtracting it
        # from total_episodes would wrap around to ~4e9 once the budget is
        # exhausted (e.g. when resuming a checkpoint with a smaller budget).
        start_iteration = int(state.metrics.iterations)
        sampled_episodes = int(state.metrics.sampled_episodes)
        num_iters = max(
            0,
            math.ceil(
                (self.config.total_episodes - sampled_episodes) / episodes_per_iter
            ),
        )

        final_iteration = start_iteration + num_iters
        for i in range(start_iteration, final_iteration):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            if train_metrics_dict["rl_metrics"] is not None:
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )

            self.recorder.write(train_metrics_dict, iters)

            if iters % self.config.eval_interval == 0 or iters == final_iteration:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = jtu.tree_map(
                    get_1d_array,
                    eval_metrics.to_local_dict(),
                )

                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=iters == final_iteration,
            )

        return state
289
[docs]
def build_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
    agent_state_vmap_axes: AgentState,
):
    """Build a TD3 update for a population of independent actor-critic pairs.

    Every individual owns its own actor, critic and target networks. Losses
    are computed with ``jax.vmap`` over the population according to
    ``agent_state_vmap_axes`` (the workflow uses ``params=0`` with a shared
    obs preprocessor state). The per-individual losses are summed before
    differentiation. Because each loss depends only on its own individual's
    params, each individual receives the gradient of its own loss.

    Args:
        agent: TD3 agent providing ``critic_loss`` and ``actor_loss``.
        optimizer: optax optimizer used for both the actor and critic updates.
        config: reads ``num_learning_offspring``, ``actor_update_interval``
            and ``tau``.
        agent_state_vmap_axes: vmap axes of the agent state over the population.

    Returns:
        ``_update_fn(agent_state, opt_state, sample_batches, key)``.
        ``opt_state`` must have ``actor`` and ``critic`` entries.
        ``sample_batches`` is indexed along its leading axis, presumably of
        length ``actor_update_interval``. Each slice carries a leading axis of
        size ``num_learning_offspring``.

        The first ``actor_update_interval - 1`` slices drive critic-only
        updates. The last slice drives one more critic update, one actor
        update and a soft target update. The function returns
        ``((agent_state, opt_state),
        (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict))``,
        where the losses come from that last step.
    """
    num_learning_offspring = config.num_learning_offspring

    def critic_loss_fn(agent_state, sample_batch, key):
        # One critic loss per individual; each individual uses its own slice
        # of sample_batch (leading axis = num_learning_offspring).
        loss_dict = jax.vmap(agent.critic_loss, in_axes=(agent_state_vmap_axes, 0, 0))(
            agent_state, sample_batch, jax.random.split(key, num_learning_offspring)
        )

        # Sum (not mean) over individuals so each individual's gradient is not
        # scaled down by 1 / num_learning_offspring.
        loss = loss_dict.critic_loss.sum()

        return loss, loss_dict

    def actor_loss_fn(agent_state, sample_batch, key):
        # One actor loss per individual, against that individual's own critic.
        loss_dict = jax.vmap(agent.actor_loss, in_axes=(agent_state_vmap_axes, 0, 0))(
            agent_state, sample_batch, jax.random.split(key, num_learning_offspring)
        )

        # sum over the num_learning_offspring
        loss = loss_dict.actor_loss.sum()

        return loss, loss_dict

    # Gradient steps restricted to critic_params / actor_params respectively.
    critic_update_fn = agent_gradient_update(
        critic_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=lambda agent_state, critic_params: agent_state.replace(
            params=agent_state.params.replace(critic_params=critic_params)
        ),
        detach_fn=lambda agent_state: agent_state.params.critic_params,
    )

    actor_update_fn = agent_gradient_update(
        actor_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=lambda agent_state, actor_params: agent_state.replace(
            params=agent_state.params.replace(actor_params=actor_params)
        ),
        detach_fn=lambda agent_state: agent_state.params.actor_params,
    )

    def _update_fn(agent_state, opt_state, sample_batches, key):
        critic_opt_state = opt_state.critic
        actor_opt_state = opt_state.actor

        key, critic_key, actor_key = jax.random.split(key, num=3)

        critic_sample_batches = jtu.tree_map(lambda x: x[:-1], sample_batches)
        last_sample_batch = jtu.tree_map(lambda x: x[-1], sample_batches)

        # Delayed policy updates: extra critic-only steps before the final
        # critic + actor step.
        if config.actor_update_interval > 1:

            def _update_critic_fn(carry, sample_batch):
                key, agent_state, critic_opt_state = carry

                key, critic_key = jax.random.split(key)

                # Only the updated state is carried; per-step losses are dropped.
                _, agent_state, critic_opt_state = critic_update_fn(
                    critic_opt_state, agent_state, sample_batch, critic_key
                )

                return (key, agent_state, critic_opt_state), None

            key, critic_multiple_update_key = jax.random.split(key)

            (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                _update_critic_fn,
                (
                    critic_multiple_update_key,
                    agent_state,
                    critic_opt_state,
                ),
                critic_sample_batches,
                length=config.actor_update_interval - 1,
            )

        (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
            critic_update_fn(
                critic_opt_state, agent_state, last_sample_batch, critic_key
            )
        )

        (actor_loss, actor_loss_dict), agent_state, actor_opt_state = actor_update_fn(
            actor_opt_state, agent_state, last_sample_batch, actor_key
        )

        # not need vmap: the soft update acts leaf-wise on the stacked params
        target_actor_params = soft_target_update(
            agent_state.params.target_actor_params,
            agent_state.params.actor_params,
            config.tau,
        )
        target_critic_params = soft_target_update(
            agent_state.params.target_critic_params,
            agent_state.params.critic_params,
            config.tau,
        )
        agent_state = agent_state.replace(
            params=agent_state.params.replace(
                target_actor_params=target_actor_params,
                target_critic_params=target_critic_params,
            )
        )

        opt_state = opt_state.replace(actor=actor_opt_state, critic=critic_opt_state)

        return (
            (agent_state, opt_state),
            (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
        )

    return _update_fn
412 return (
413 (agent_state, opt_state),
414 (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
415 )
416
417 return _update_fn