1import logging
2from functools import partial
3import math
4
5import chex
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9import optax
10from omegaconf import DictConfig
11
12from evorl.distributed import psum
13from evorl.distributed.gradients import agent_gradient_update
14from evorl.envs import AutoresetMode, create_env
15from evorl.evaluators import Evaluator
16from evorl.metrics import MetricBase
17from evorl.rollout import rollout
18from evorl.types import (
19 PyTreeDict,
20 State,
21)
22from evorl.utils import running_statistics
23from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, scan_and_last
24from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
25from evorl.agent import Agent, AgentState
26from evorl.recorders import add_prefix
27from evorl.workflows import OnPolicyWorkflow
28
29from evorl.algorithms.td3 import make_mlp_td3_agent, clean_trajectory, TD3TrainMetric
30
31logger = logging.getLogger(__name__)
32
33MISSING_LOSS = -1e10
34
35
[docs]
class TD3OnPolicyWorkflow(OnPolicyWorkflow):
    """TD3 workflow that trains directly on freshly collected rollouts.

    Each call to :meth:`step` collects one rollout from the training env,
    shuffles it into minibatches and runs ``config.reuse_rollout_epochs``
    passes of TD3 updates over it. No replay buffer is used in this class.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "TD3-OnPolicy"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Convert global batch sizes in ``config`` to per-device sizes, in place.

        ``num_envs``, ``num_eval_envs`` and ``minibatch_size`` are floor-divided
        by ``jax.device_count()``. A warning is logged for every field that is
        not an exact multiple of the device count.

        Args:
            config: Workflow config; mutated in place.
        """
        num_devices = jax.device_count()
        per_device_fields = ("num_envs", "num_eval_envs", "minibatch_size")

        # Emit every warning before mutating anything, so the logged values
        # are the user-provided (global) ones.
        for field in per_device_fields:
            value = config[field]
            if value % num_devices != 0:
                logger.warning(
                    f"{field}({value}) cannot be divided by num_devices({num_devices}), "
                    f"rescale {field} to {value // num_devices}"
                )

        for field in per_device_fields:
            config[field] = config[field] // num_devices

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the envs, agent, optimizer and evaluator described by ``config``.

        Args:
            config: Workflow config (already rescaled to per-device sizes is
                presumably expected; confirm against the base class).

        Returns:
            A new workflow instance.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is optional: enabled only for a positive norm.
        grad_clip_norm = config.optimizer.grad_clip_norm
        if grad_clip_norm is not None and grad_clip_norm > 0:
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        one_step_rollout_steps = config.num_envs * config.rollout_length
        if one_step_rollout_steps % config.minibatch_size != 0:
            # Fixed: the original read the misspelled key `config.minibath_size`,
            # which broke this warning exactly when it was needed.
            logger.warning(
                f"minibatch_size ({config.minibatch_size}) cannot divide "
                f"num_envs*rollout_length ({one_step_rollout_steps})"
            )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, config)

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialise the agent state and one optimizer state per network group.

        The same optimizer definition is used for both groups, but the actor
        and the critic each get their own optimizer state.

        Args:
            key: PRNG key used to initialise the agent.

        Returns:
            ``(agent_state, opt_state)`` where ``opt_state`` has the fields
            ``actor`` and ``critic``.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Collect one rollout and train on it.

        Args:
            state: Current workflow state.

        Returns:
            ``(train_metrics, new_state)``.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # trajectory: [T, #envs, ...]
        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Observation statistics are refreshed from the new rollout only when
        # an observation preprocessor state exists.
        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        trajectory = clean_trajectory(trajectory)
        # After flattening, the leading axis is assumed to hold
        # rollout_length * num_envs samples -- TODO confirm against
        # flatten_rollout_trajectory.
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)
        # ============================

        update_fn = build_rl_update_fn(self.agent, self.optimizer, self.config)

        # One call of update_fn consumes `actor_update_interval` minibatches.
        samples_per_update = (
            self.config.minibatch_size * self.config.actor_update_interval
        )
        num_minibatches = (
            self.config.rollout_length * self.config.num_envs // samples_per_update
        )

        def _get_shuffled_minibatch(perm_key, x):
            # Shuffle along axis 0, drop the remainder that does not fill a
            # whole update, then reshape to
            # [num_minibatches, actor_update_interval, minibatch_size, ...].
            x = jax.random.permutation(perm_key, x)[
                : num_minibatches * samples_per_update
            ]
            return x.reshape(
                num_minibatches,
                self.config.actor_update_interval,
                self.config.minibatch_size,
                *x.shape[1:],
            )

        def minibatch_step(carry, trajectory):
            # trajectory: [actor_update_interval, B, ...]

            opt_state, agent_state, key = carry
            key, learn_key = jax.random.split(key)

            (agent_state, opt_state), train_info = update_fn(
                agent_state, opt_state, trajectory, learn_key
            )

            return (opt_state, agent_state, key), train_info

        def epoch_step(carry, _):
            opt_state, agent_state, key = carry
            perm_key, learn_key = jax.random.split(key, num=2)

            # The same perm_key is used for every leaf so that all fields of
            # the trajectory are shuffled with the same permutation.
            (opt_state, agent_state, key), train_info = scan_and_mean(
                minibatch_step,
                (opt_state, agent_state, learn_key),
                jtu.tree_map(partial(_get_shuffled_minibatch, perm_key), trajectory),
                length=num_minibatches,
            )

            return (opt_state, agent_state, key), train_info

        # NOTE(review): going by the helper names, losses are presumably
        # averaged over minibatches (scan_and_mean) and only the last epoch's
        # values are kept (scan_and_last) -- confirm in evorl.utils.jax_utils.
        (
            (opt_state, agent_state, _),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_last(
            epoch_step,
            (state.opt_state, agent_state, learn_key),
            None,
            length=self.config.reuse_rollout_epochs,
        )

        # ======== update metrics ========

        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        train_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run the training loop.

        The number of iterations is
        ``ceil(total_timesteps / (rollout_length * num_envs))``. Every
        iteration records the workflow and training metrics and calls the
        checkpoint manager; evaluation runs every ``config.eval_interval``
        iterations and on the final iteration.

        Args:
            state: Initial (or restored) workflow state.

        Returns:
            The workflow state after the last iteration.
        """
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        num_iters = math.ceil(self.config.total_timesteps / one_step_timesteps)

        # Iteration counting continues from whatever the state already holds.
        start_iteration = state.metrics.iterations.tolist()
        final_iteration = num_iters + start_iteration

        for _ in range(num_iters):
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            iterations = state.metrics.iterations.tolist()
            is_final_iteration = iterations == final_iteration

            self.recorder.write(workflow_metrics.to_local_dict(), iterations)
            self.recorder.write(train_metrics.to_local_dict(), iterations)

            if iterations % self.config.eval_interval == 0 or is_final_iteration:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            # The last iteration is always saved (force=True).
            self.checkpoint_manager.save(
                iterations,
                state,
                force=is_final_iteration,
            )

        return state
280
281
[docs]
def build_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
):
    """Build the closure that performs one delayed-actor TD3 update cycle.

    The returned function has the signature
    ``(agent_state, opt_state, sample_batches, key)``. ``sample_batches`` is
    indexed along its leading axis: every slice except the last is used for a
    critic-only update, and the last slice is used for one critic update
    followed by one actor update. Both target networks are then refreshed
    through ``soft_target_update`` with ``config.tau``.

    Args:
        agent: Agent providing ``critic_loss`` and ``actor_loss``.
        optimizer: Optimizer shared by the actor and critic updates.
        config: Config providing ``actor_update_interval`` and ``tau``.

    Returns:
        A function returning ``((agent_state, opt_state),
        (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict))``.
    """

    def critic_loss_fn(agent_state, sample_batch, key):
        # Scalar critic loss, with the full loss dict as the aux output.
        critic_losses = agent.critic_loss(agent_state, sample_batch, key)
        return critic_losses.critic_loss, critic_losses

    def actor_loss_fn(agent_state, sample_batch, key):
        # Scalar actor loss, with the full loss dict as the aux output.
        actor_losses = agent.actor_loss(agent_state, sample_batch, key)
        return actor_losses.actor_loss, actor_losses

    def _attach_critic_params(agent_state, critic_params):
        updated_params = agent_state.params.replace(critic_params=critic_params)
        return agent_state.replace(params=updated_params)

    def _detach_critic_params(agent_state):
        return agent_state.params.critic_params

    def _attach_actor_params(agent_state, actor_params):
        updated_params = agent_state.params.replace(actor_params=actor_params)
        return agent_state.replace(params=updated_params)

    def _detach_actor_params(agent_state):
        return agent_state.params.actor_params

    critic_update_fn = agent_gradient_update(
        critic_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_attach_critic_params,
        detach_fn=_detach_critic_params,
    )

    actor_update_fn = agent_gradient_update(
        actor_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_attach_actor_params,
        detach_fn=_detach_actor_params,
    )

    def _update_fn(agent_state, opt_state, sample_batches, key):
        key, critic_key, actor_key = jax.random.split(key, num=3)

        critic_opt_state = opt_state.critic
        actor_opt_state = opt_state.actor

        final_batch = jtu.tree_map(lambda leaf: leaf[-1], sample_batches)
        num_critic_only_updates = config.actor_update_interval - 1

        if num_critic_only_updates > 0:
            leading_batches = jtu.tree_map(lambda leaf: leaf[:-1], sample_batches)

            def _critic_only_step(carry, batch):
                scan_key, scan_agent_state, scan_opt_state = carry
                scan_key, step_key = jax.random.split(scan_key)

                # Loss values of the intermediate critic updates are discarded.
                _, scan_agent_state, scan_opt_state = critic_update_fn(
                    scan_opt_state, scan_agent_state, batch, step_key
                )

                return (scan_key, scan_agent_state, scan_opt_state), None

            key, scan_init_key = jax.random.split(key)

            scan_init = (scan_init_key, agent_state, critic_opt_state)
            (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                _critic_only_step,
                scan_init,
                leading_batches,
                length=num_critic_only_updates,
            )

        # Final critic update; its loss values are the ones reported.
        critic_result = critic_update_fn(
            critic_opt_state, agent_state, final_batch, critic_key
        )
        (critic_loss, critic_loss_dict), agent_state, critic_opt_state = critic_result

        # The actor is updated once per cycle, on the same final batch.
        actor_result = actor_update_fn(
            actor_opt_state, agent_state, final_batch, actor_key
        )
        (actor_loss, actor_loss_dict), agent_state, actor_opt_state = actor_result

        # Refresh both target networks from the just-updated online networks.
        params = agent_state.params
        new_target_actor = soft_target_update(
            params.target_actor_params,
            params.actor_params,
            config.tau,
        )
        new_target_critic = soft_target_update(
            params.target_critic_params,
            params.critic_params,
            config.tau,
        )
        agent_state = agent_state.replace(
            params=params.replace(
                target_actor_params=new_target_actor,
                target_critic_params=new_target_critic,
            )
        )

        opt_state = opt_state.replace(actor=actor_opt_state, critic=critic_opt_state)

        train_info = (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict)
        return (agent_state, opt_state), train_info

    return _update_fn