1import logging
2import math
3
4import jax
5import jax.numpy as jnp
6import jax.tree_util as jtu
7
8from evorl.distributed import psum
9from evorl.distributed.gradients import agent_gradient_update
10from evorl.metrics import MetricBase
11from evorl.rollout import rollout
12from evorl.types import (
13 PyTreeDict,
14 State,
15)
16from evorl.utils import running_statistics
17from evorl.utils.jax_utils import tree_stop_gradient
18from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
19from evorl.recorders import add_prefix
20
21from evorl.algorithms.offpolicy_utils import clean_trajectory, skip_replay_buffer_state
22from evorl.algorithms.td3 import TD3TrainMetric, TD3Workflow
23
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Sentinel used as the actor loss on iterations where the actor update is
# skipped (see `_dummy_update_actor_fn` in `TD3V2Workflow.step`).
# `TD3V2Workflow.learn` maps metric leaves equal to it to None before recording.
MISSING_LOSS = -1e10
27
28
[docs]
class TD3V2Workflow(TD3Workflow):
    """TD3 variant whose update schedule is similar to the SB3 and CleanRL ones.

    Each call to :meth:`step` collects one rollout, adds it to the replay
    buffer, samples a single batch and updates the critic with it. The actor
    and both target networks are updated on that same batch only on
    iterations that are a multiple of ``config.actor_update_interval``.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "TD3-V2"

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: rollout, buffer insert, network updates.

        Args:
            state: Current workflow state. This method reads its ``key``,
                ``metrics``, ``env_state``, ``agent_state``, ``opt_state`` and
                ``replay_buffer_state`` fields.

        Returns:
            A ``(train_metrics, new_state)`` tuple. ``train_metrics`` is a
            ``TD3TrainMetric``; its ``actor_loss`` is filled with
            ``MISSING_LOSS`` (before the cross-device reduction) on iterations
            where the actor update is skipped.
        """
        # Index of the iteration being executed. It drives the delayed actor
        # update below and is stored back as the iteration counter.
        iterations = state.metrics.iterations + 1
        key, rollout_key, rb_key, critic_key, actor_key = jax.random.split(
            state.key, num=5
        )

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        opt_state = state.opt_state

        # Refresh the observation statistics with the newly collected
        # observations when an observation preprocessor is configured.
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def critic_loss_fn(agent_state, sample_batch, key):
            # Return (scalar loss, aux dict), matching has_aux=True below.
            loss_dict = self.agent.critic_loss(agent_state, sample_batch, key)

            loss = loss_dict.critic_loss
            return loss, loss_dict

        def actor_loss_fn(agent_state, sample_batch, key):
            # Return (scalar loss, aux dict), matching has_aux=True below.
            loss_dict = self.agent.actor_loss(agent_state, sample_batch, key)

            loss = loss_dict.actor_loss
            return loss, loss_dict

        # detach_fn selects the parameter subtree to differentiate and
        # attach_fn writes the updated subtree back into the agent state, so
        # each update function only touches its own network.
        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, critic_params: agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, actor_params: agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )

        def _update_critic_fn(agent_state, opt_state, sample_batch, key):
            critic_opt_state = opt_state.critic

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(critic_opt_state, agent_state, sample_batch, key)
            )

            opt_state = opt_state.replace(critic=critic_opt_state)

            return (
                critic_loss,
                critic_loss_dict,
                agent_state,
                opt_state,
            )

        def _update_actor_fn(agent_state, opt_state, sample_batch, key):
            actor_opt_state = opt_state.actor

            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(actor_opt_state, agent_state, sample_batch, key)
            )

            # Both target networks are only moved together with the actor,
            # i.e. on the delayed-update iterations.
            target_actor_params = soft_target_update(
                agent_state.params.target_actor_params,
                agent_state.params.actor_params,
                self.config.tau,
            )
            target_critic_params = soft_target_update(
                agent_state.params.target_critic_params,
                agent_state.params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    target_actor_params=target_actor_params,
                    target_critic_params=target_critic_params,
                )
            )

            opt_state = opt_state.replace(actor=actor_opt_state)

            return (
                actor_loss,
                actor_loss_dict,
                agent_state,
                opt_state,
            )

        def _dummy_update_actor_fn(agent_state, opt_state, sample_batch, key):
            # No-op branch for jax.lax.cond: states pass through unchanged and
            # the loss slots carry the MISSING_LOSS sentinel. Its outputs must
            # have the same structure as those of `_update_actor_fn`.
            actor_loss = jnp.full((), fill_value=MISSING_LOSS)
            actor_loss_dict = PyTreeDict(actor_loss=actor_loss)

            return (
                actor_loss,
                actor_loss_dict,
                agent_state,
                opt_state,
            )

        # A single batch is shared by the critic update and the (optional)
        # actor update of this iteration.
        sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

        critic_loss, critic_loss_dict, agent_state, opt_state = _update_critic_fn(
            agent_state, opt_state, sample_batch, critic_key
        )

        # Note: using cond prohibits parallel training by vmap
        (
            actor_loss,
            actor_loss_dict,
            agent_state,
            opt_state,
        ) = jax.lax.cond(
            iterations % self.config.actor_update_interval == 0,
            _update_actor_fn,
            _dummy_update_actor_fn,
            agent_state,
            opt_state,
            sample_batch,
            actor_key,
        )

        train_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # Number of environment steps sampled by this call, combined over the
        # data-parallel axis with psum.
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            iterations=iterations,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run the training loop with periodic recording, evaluation and saving.

        The number of ``_multi_steps`` calls is derived from the timesteps
        still missing to reach ``config.total_timesteps``, divided by
        ``rollout_length * num_envs * fold_iters * device_count``.

        Args:
            state: Workflow state to start (or resume) training from.

        Returns:
            The workflow state after the last training call.
        """
        num_devices = jax.device_count()
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        sampled_timesteps = state.metrics.sampled_timesteps.tolist()
        num_iters = math.ceil(
            (self.config.total_timesteps - sampled_timesteps)
            / (one_step_timesteps * self.config.fold_iters * num_devices)
        )

        for i in range(num_iters):
            # `_multi_steps` is defined outside this class body; presumably it
            # runs `config.fold_iters` calls of `step` -- TODO confirm.
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            # current iteration
            iterations = state.metrics.iterations.tolist()

            # Detect the last call from the loop index rather than from the
            # iteration counter: the counter only equals
            # `start_iteration + num_iters` if every `_multi_steps` call
            # advances it by exactly one, whereas the loop index is always
            # right.
            is_last_call = i == num_iters - 1

            # Drop the sentinel written when the actor update was skipped so
            # that it is not recorded as a real loss value.
            train_metrics = jtu.tree_map(
                lambda x: None if x == MISSING_LOSS else x, train_metrics
            )

            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            if iterations % self.config.eval_interval == 0 or is_last_call:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            # Always force a checkpoint for the final training call.
            self.checkpoint_manager.save(
                iterations, saved_state, force=is_last_call
            )

        return state
255 iterations, saved_state, force=iterations == final_iteration
256 )
257
258 return state