1import jax
2import jax.numpy as jnp
3import jax.tree_util as jtu
4import optax
5
6from omegaconf import DictConfig
7from evorl.agent import Agent, AgentState
8
9from evorl.distributed import agent_gradient_update
10from evorl.types import PyTreeDict
11from evorl.utils.jax_utils import (
12 right_shift_with_padding,
13 tree_stop_gradient,
14 scan_and_mean,
15)
16from evorl.utils.rl_toolkits import (
17 flatten_rollout_trajectory,
18 flatten_pop_rollout_episode,
19 soft_target_update,
20)
21from evorl.algorithms.td3 import TD3TrainMetric, TD3NetworkParams
22
23from ..cemrl_workflow import CEMRLWorkflowBase
24
25
[docs]
class CEMRLTD3WorkflowTemplate(CEMRLWorkflowBase):
    """Template for CEM-RL workflows that train a TD3 agent.

    Exposes ``_rollout`` (population rollout into the replay buffer) and
    ``_rl_update`` (repeated sample-and-update rounds) on top of
    ``CEMRLWorkflowBase``.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the base class and build the RL update step."""
        super().__init__(**kwargs)
        # Built once and reused by every `_rl_update` call.
        self._rl_update_fn = build_cemrl_rl_update_fn(
            agent=self.agent,
            optimizer=self.optimizer,
            config=self.config,
            agent_state_vmap_axes=self.agent_state_vmap_axes,
        )

    def _rollout(self, pop_agent_state, replay_buffer_state, key):
        """Roll out the population via ``rollout_episode``.

        Uses ``config.episodes_for_fitness`` episodes and ``config.pop_size``
        agents, and returns whatever ``rollout_episode`` returns.
        """
        cfg = self.config
        return rollout_episode(
            pop_agent_state,
            replay_buffer_state,
            key,
            collector=self.collector,
            replay_buffer=self.replay_buffer,
            agent_state_vmap_axes=self.agent_state_vmap_axes,
            num_episodes=cfg.episodes_for_fitness,
            num_agents=cfg.pop_size,
        )

    def _rl_update(self, agent_state, opt_state, replay_buffer_state, key):
        """Run ``config.num_rl_updates_per_iter`` sample-and-update rounds.

        Each round draws ``actor_update_interval * num_learning_offspring``
        batches from the replay buffer and feeds them to the update function
        built in ``__init__``.

        Returns:
            ``(td3_metrics, agent_state, opt_state)``; the metric values are
            the per-round outputs aggregated by ``scan_and_mean``.
        """
        interval = self.config.actor_update_interval
        n_learners = self.config.num_learning_offspring

        def _draw_batch(sample_key):
            return self.replay_buffer.sample(replay_buffer_state, sample_key)

        def _regroup(leaf):
            # (interval * n_learners, B, ...) -> (interval, n_learners, B, ...)
            return leaf.reshape((interval, n_learners) + tuple(leaf.shape[1:]))

        def _one_round(carry, unused_t):
            round_key, cur_agent_state, cur_opt_state = carry
            round_key, rb_key, learn_key = jax.random.split(round_key, 3)

            # One independent batch per (critic step, learner) pair.
            flat_batches = jax.vmap(_draw_batch)(
                jax.random.split(rb_key, interval * n_learners)
            )
            grouped_batches = jtu.tree_map(_regroup, flat_batches)

            (cur_agent_state, cur_opt_state), train_info = self._rl_update_fn(
                cur_agent_state, cur_opt_state, grouped_batches, learn_key
            )
            return (round_key, cur_agent_state, cur_opt_state), train_info

        final_carry, aggregated = scan_and_mean(
            _one_round,
            (key, agent_state, opt_state),
            (),
            length=self.config.num_rl_updates_per_iter,
        )
        _, agent_state, opt_state = final_carry
        critic_loss, actor_loss, critic_loss_dict, actor_loss_dict = aggregated

        # smoothed td3 metrics
        merged_loss_dict = {**critic_loss_dict, **actor_loss_dict}
        # The actor loss built in this module is summed over the learning
        # offspring, so divide to report a per-learner value.
        td3_metrics = TD3TrainMetric(
            actor_loss=actor_loss / n_learners,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict(merged_loss_dict),
        )

        return td3_metrics, agent_state, opt_state
110
111
[docs]
def cemrl_replace_td3_actor_params(
    agent_state: AgentState, pop_actor_params: TD3NetworkParams
) -> AgentState:
    """Install ``pop_actor_params`` as both the actor and the target actor.

    Only ``actor_params`` and ``target_actor_params`` are replaced; every
    other field of ``agent_state.params`` (including the critic) is kept.

    Args:
        agent_state: Agent state whose ``params`` are updated.
        pop_actor_params: Parameters written to both actor fields.

    Returns:
        The result of ``agent_state.replace`` with the updated ``params``.
    """
    updated_params = agent_state.params.replace(
        actor_params=pop_actor_params,
        target_actor_params=pop_actor_params,
    )
    return agent_state.replace(params=updated_params)
123
124
# Zero-filled TD3TrainMetric: every field is a shape-() array of zeros.
# `create_dummy_td3_trainmetric` uses it as the template it broadcasts from.
DUMMY_TD3_TRAINMETRIC = TD3TrainMetric(
    critic_loss=jnp.zeros(()),
    actor_loss=jnp.zeros(()),
    # Per-term entries of the raw loss dictionary.
    raw_loss_dict=PyTreeDict(
        critic_loss=jnp.zeros(()),
        q_value=jnp.zeros(()),
        actor_loss=jnp.zeros(()),
    ),
)
134
135
[docs]
def create_dummy_td3_trainmetric(num: int) -> TD3TrainMetric:
    """Return a zero metric whose ``raw_loss_dict`` leaves have a leading axis.

    ``critic_loss`` and ``actor_loss`` stay as in ``DUMMY_TD3_TRAINMETRIC``;
    only the leaves of ``raw_loss_dict`` are broadcast to ``(num, *shape)``.

    Args:
        num: Size of the leading axis; must be at least 1.

    Raises:
        ValueError: If ``num`` is not ``>= 1``.
    """
    if not (num >= 1):
        raise ValueError(f"num should be positive, got {num}")

    def _add_leading_axis(leaf):
        return jnp.broadcast_to(leaf, (num, *leaf.shape))

    tiled_loss_dict = jtu.tree_map(
        _add_leading_axis, DUMMY_TD3_TRAINMETRIC.raw_loss_dict
    )
    return DUMMY_TD3_TRAINMETRIC.replace(raw_loss_dict=tiled_loss_dict)
146
147
[docs]
def rollout_episode(
    agent_state: AgentState,
    replay_buffer_state,
    key,
    *,
    collector,
    replay_buffer,
    agent_state_vmap_axes,
    num_episodes,
    num_agents,
):
    """Roll out a population of agents and add the data to the replay buffer.

    Args:
        agent_state: Agent state, vmapped according to
            ``agent_state_vmap_axes``.
        replay_buffer_state: Buffer state passed to ``replay_buffer.add``.
        key: PRNG key, split into one key per agent.
        collector: Object whose ``rollout`` method is vmapped over agents.
        replay_buffer: Object providing ``add(state, trajectory, mask)``.
        agent_state_vmap_axes: ``in_axes`` spec for ``agent_state``.
        num_episodes: Passed unbatched to ``collector.rollout``.
        num_agents: Number of per-agent keys to create.

    Returns:
        ``(eval_metrics, trajectory, replay_buffer_state)`` where
        ``trajectory`` is the flattened, gradient-stopped data that was added.
    """
    agent_keys = jax.random.split(key, num_agents)
    population_rollout = jax.vmap(
        collector.rollout,
        in_axes=(agent_state_vmap_axes, 0, None),
    )
    eval_metrics, trajectory = population_rollout(
        agent_state, agent_keys, num_episodes
    )

    # [n, T, B, ...] -> [T, n*B, ...]
    trajectory = flatten_pop_rollout_episode(trajectory.replace(next_obs=None))

    # Computed from `dones` shifted right by one step, then negated.
    # NOTE(review): presumably this invalidates the steps recorded after an
    # episode has already finished -- confirm against right_shift_with_padding.
    valid_mask = jnp.logical_not(right_shift_with_padding(trajectory.dones, 1))

    flattened = flatten_rollout_trajectory(
        (trajectory.replace(dones=None), valid_mask)
    )
    trajectory, valid_mask = tree_stop_gradient(flattened)

    replay_buffer_state = replay_buffer.add(
        replay_buffer_state, trajectory, valid_mask
    )

    return eval_metrics, trajectory, replay_buffer_state
180
181
[docs]
def build_cemrl_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
    agent_state_vmap_axes: AgentState,
):
    """Build the TD3 update step for K learning actors and 1 shared critic.

    Args:
        agent: Provides ``critic_loss`` and ``actor_loss``.
        optimizer: Used for both the critic and the actor update.
        config: Read for ``num_learning_offspring``,
            ``actor_update_interval`` and ``tau``.
        agent_state_vmap_axes: ``in_axes`` spec for the agent state when the
            losses are vmapped over the learning offspring.

    Returns:
        A function ``(agent_state, opt_state, sample_batches, key)`` returning
        ``((agent_state, opt_state), (critic_loss, actor_loss,
        critic_loss_dict, actor_loss_dict))``. Along the leading axis of
        ``sample_batches``, every entry but the last drives a critic-only
        update; the last entry drives one critic update, one actor update and
        the soft target updates.
    """
    num_learning_offspring = config.num_learning_offspring
    loss_in_axes = (agent_state_vmap_axes, 0, 0)

    def _loss_per_learner(single_loss_fn, agent_state, sample_batch, key):
        # Evaluate `single_loss_fn` once per learner, each with its own key.
        learner_keys = jax.random.split(key, num_learning_offspring)
        return jax.vmap(single_loss_fn, in_axes=loss_in_axes)(
            agent_state, sample_batch, learner_keys
        )

    def critic_loss_fn(agent_state, sample_batch, key):
        # loss on a single critic with multiple actors
        # sample_batch: (n, B, ...)
        loss_dict = _loss_per_learner(
            agent.critic_loss, agent_state, sample_batch, key
        )
        # mean over the num_learning_offspring
        return loss_dict.critic_loss.mean(), loss_dict

    def actor_loss_fn(agent_state, sample_batch, key):
        # loss on a single actor
        loss_dict = _loss_per_learner(
            agent.actor_loss, agent_state, sample_batch, key
        )
        # sum over the num_learning_offspring
        return loss_dict.actor_loss.sum(), loss_dict

    def _with_critic_params(agent_state, critic_params):
        return agent_state.replace(
            params=agent_state.params.replace(critic_params=critic_params)
        )

    def _critic_params_of(agent_state):
        return agent_state.params.critic_params

    def _with_actor_params(agent_state, actor_params):
        return agent_state.replace(
            params=agent_state.params.replace(actor_params=actor_params)
        )

    def _actor_params_of(agent_state):
        return agent_state.params.actor_params

    critic_update_fn = agent_gradient_update(
        critic_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_with_critic_params,
        detach_fn=_critic_params_of,
    )

    actor_update_fn = agent_gradient_update(
        actor_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_with_actor_params,
        detach_fn=_actor_params_of,
    )

    def _update_fn(agent_state, opt_state, sample_batches, key):
        critic_opt_state = opt_state.critic
        actor_opt_state = opt_state.actor

        key, critic_key, actor_key = jax.random.split(key, 3)

        last_sample_batch = jtu.tree_map(lambda leaf: leaf[-1], sample_batches)

        num_critic_only_updates = config.actor_update_interval - 1
        if num_critic_only_updates > 0:
            # Every batch except the last one only updates the critic.
            critic_sample_batches = jtu.tree_map(
                lambda leaf: leaf[:-1], sample_batches
            )

            def _critic_only_step(carry, sample_batch):
                step_key, cur_agent_state, cur_critic_opt_state = carry
                step_key, step_critic_key = jax.random.split(step_key)

                # Losses of these intermediate steps are not reported.
                _, cur_agent_state, cur_critic_opt_state = critic_update_fn(
                    cur_critic_opt_state,
                    cur_agent_state,
                    sample_batch,
                    step_critic_key,
                )
                return (step_key, cur_agent_state, cur_critic_opt_state), None

            key, critic_multiple_update_key = jax.random.split(key)

            scan_init = (critic_multiple_update_key, agent_state, critic_opt_state)
            (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                _critic_only_step,
                scan_init,
                critic_sample_batches,
                length=num_critic_only_updates,
            )

        # Final critic update, then the actor update on the same batch.
        critic_out, agent_state, critic_opt_state = critic_update_fn(
            critic_opt_state, agent_state, last_sample_batch, critic_key
        )
        critic_loss, critic_loss_dict = critic_out

        actor_out, agent_state, actor_opt_state = actor_update_fn(
            actor_opt_state, agent_state, last_sample_batch, actor_key
        )
        actor_loss, actor_loss_dict = actor_out

        # not need vmap
        params = agent_state.params
        agent_state = agent_state.replace(
            params=params.replace(
                target_actor_params=soft_target_update(
                    params.target_actor_params,
                    params.actor_params,
                    config.tau,
                ),
                target_critic_params=soft_target_update(
                    params.target_critic_params,
                    params.critic_params,
                    config.tau,
                ),
            )
        )

        opt_state = opt_state.replace(
            actor=actor_opt_state, critic=critic_opt_state
        )

        train_info = (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict)
        return (agent_state, opt_state), train_info

    return _update_fn