1import logging
2from typing import Any
3
4import chex
5import flax.linen as nn
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9import optax
10from omegaconf import DictConfig
11
12from evorl.distributed import psum, pmean
13from evorl.distributed.gradients import agent_gradient_update
14from evorl.envs import AutoresetMode, Box, create_env, Space
15from evorl.evaluators import Evaluator
16from evorl.metrics import MetricBase, metric_field
17from evorl.networks import make_policy_network, make_q_network
18from evorl.rollout import rollout
19from evorl.replay_buffers import ReplayBuffer
20from evorl.sample_batch import SampleBatch
21from evorl.types import (
22 Action,
23 LossDict,
24 Params,
25 PolicyExtraInfo,
26 PyTreeData,
27 PyTreeDict,
28 State,
29 pytree_field,
30)
31from evorl.utils import running_statistics
32from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, tree_get
33from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
34
35from evorl.agent import Agent, AgentState
36from .offpolicy_utils import OffPolicyWorkflowTemplate, clean_trajectory
37
38logger = logging.getLogger(__name__)
39
40
[docs]
class DDPGTrainMetric(MetricBase):
    """Training metrics produced by one DDPG workflow step."""

    # Loss value of the actor update.
    actor_loss: chex.Array
    # Loss value of the critic update.
    critic_loss: chex.Array
    # Merged loss dictionaries returned by the agent's loss functions;
    # reduced with ``pmean`` when the metric is all-reduced.
    raw_loss_dict: LossDict = metric_field(
        default_factory=PyTreeDict,
        reduce_fn=pmean,
    )
45
46
[docs]
class DDPGNetworkParams(PyTreeData):
    """Learner parameters for DDPG: online networks and their targets."""

    # Online networks, updated by gradient steps.
    actor_params: Params
    critic_params: Params

    # Target networks, used when computing the critic's bootstrap target.
    target_actor_params: Params
    target_critic_params: Params
55
56
[docs]
class DDPGAgent(Agent):
    """The Agent for DDPG.

    Bundles a deterministic actor network and a Q-value critic network and
    exposes action selection plus the critic and actor loss functions.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    # Optional observation normalizer; ``None`` disables normalization.
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    # Multiplier applied to the bootstrapped next-state Q-value.
    discount: float = 1
    # Scale of the Gaussian noise added to actions in ``compute_actions``.
    exploration_epsilon: float = 0.5

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor has been configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state, obs):
        """Return ``obs`` normalized by the preprocessor, or unchanged if none is set."""
        if not self.normalize_obs:
            return obs
        return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; a sample from it is used as a dummy
                network input.
            action_space: Action space; a sample from it is used as a dummy
                critic input.
            key: PRNG key for parameter initialization and dummy sampling.

        Returns:
            An ``AgentState`` holding the network parameters and, when
            observation normalization is enabled, the initial running
            statistics (otherwise ``None``).
        """
        key, q_key, actor_key = jax.random.split(key, num=3)

        # Dummy inputs with a leading axis of size 1, used to initialize the networks.
        dummy_obs = jtu.tree_map(lambda leaf: leaf[None, ...], obs_space.sample(key))
        dummy_action = action_space.sample(key)[None, ...]

        actor_params = self.actor_network.init(actor_key, dummy_obs)
        critic_params = self.critic_network.init(q_key, dummy_obs, dummy_action)

        # Target networks start out identical to the online networks.
        params_state = DDPGNetworkParams(
            actor_params=actor_params,
            critic_params=critic_params,
            target_actor_params=actor_params,
            target_critic_params=critic_params,
        )

        # Note: statistics are broadcasted to [T*B]
        obs_preprocessor_state = (
            running_statistics.init_state(tree_get(dummy_obs, 0))
            if self.normalize_obs
            else None
        )

        return AgentState(
            params=params_state,
            obs_preprocessor_state=obs_preprocessor_state,
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute exploratory actions for the observations in ``sample_batch``.

        The actor output is perturbed with Gaussian noise scaled by
        ``exploration_epsilon`` and clipped to ``[-1, 1]``.

        Returns:
            The noisy actions and an empty extra-info dict.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        policy_actions = self.actor_network.apply(agent_state.params.actor_params, obs)

        # Exploration noise, same shape as the actor output.
        scaled_noise = (
            jax.random.normal(key, policy_actions.shape) * self.exploration_epsilon
        )
        noisy_actions = jnp.clip(policy_actions + scaled_noise, -1.0, 1.0)

        return noisy_actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute actions for evaluation: the actor output without added noise.

        ``key`` is accepted for interface compatibility and is not used.

        Returns:
            The actor's actions and an empty extra-info dict.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        return (
            self.actor_network.apply(agent_state.params.actor_params, obs),
            PyTreeDict(),
        )

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the critic's mean squared TD error on ``sample_batch``.

        The target is ``rewards + discount * (1 - termination) * Q_target(
        next_obs, actor_target(next_obs))`` with gradients stopped.

        Returns:
            A dict with ``critic_loss`` and the mean predicted ``q_value``.
        """
        params = agent_state.params
        env_extras = sample_batch.extras.env_extras

        # NOTE(review): ``ori_obs`` is used as the next observation; presumably
        # it is the observation recorded before any auto-reset - confirm
        # against the env wrapper.
        next_obs = self._preprocess_obs(agent_state, env_extras.ori_obs)
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        # Bootstrap value from the target actor and target critic.
        next_actions = self.actor_network.apply(params.target_actor_params, next_obs)
        next_qs = self.critic_network.apply(
            params.target_critic_params, next_obs, next_actions
        )

        # Terminated transitions get a zero discount, so they do not bootstrap.
        discounts = self.discount * (1 - env_extras.termination)
        qs_target = jax.lax.stop_gradient(sample_batch.rewards + discounts * next_qs)

        qs = self.critic_network.apply(
            params.critic_params, obs, sample_batch.actions
        )

        # Squared error is used; a Huber loss was previously considered here.
        q_loss = optax.squared_error(qs, qs_target).mean()

        return PyTreeDict(critic_loss=q_loss, q_value=qs.mean())

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the actor loss: the negated mean critic value of the actor's actions.

        Returns:
            A dict with ``actor_loss``.
        """
        params = agent_state.params
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        # [T*B, A]
        actions = self.actor_network.apply(params.actor_params, obs)
        q_values = self.critic_network.apply(params.critic_params, obs, actions)

        return PyTreeDict(actor_loss=-jnp.mean(q_values))
178
179
[docs]
def make_mlp_ddpg_agent(
    action_space: Space,
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    discount: float = 1,
    exploration_epsilon: float = 0.5,
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> DDPGAgent:
    """Build a ``DDPGAgent`` with an actor and a Q-network critic.

    Args:
        action_space: Action space of the environment; must be a ``Box``.
        critic_hidden_layer_sizes: Hidden layer sizes of the critic network.
        actor_hidden_layer_sizes: Hidden layer sizes of the actor network.
        discount: Discount value passed to the agent.
        exploration_epsilon: Exploration noise scale passed to the agent.
        normalize_obs: If true, the agent normalizes observations with
            ``running_statistics.normalize``.
        policy_obs_key: Observation key forwarded to the actor network.
        value_obs_key: Observation key forwarded to the critic network.

    Returns:
        The configured ``DDPGAgent``.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    assert isinstance(action_space, Box), "Only continuous action space is supported."

    # The first dimension of the Box shape is taken as the action size.
    action_size = action_space.shape[0]

    critic_network = make_q_network(
        hidden_layer_sizes=critic_hidden_layer_sizes,
        obs_key=value_obs_key,
    )
    # The final tanh keeps actor outputs within [-1, 1].
    actor_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        activation_final=nn.tanh,
        obs_key=policy_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return DDPGAgent(
        critic_network=critic_network,
        actor_network=actor_network,
        obs_preprocessor=obs_preprocessor,
        discount=discount,
        exploration_epsilon=exploration_epsilon,
    )
217
218
[docs]
class DDPGWorkflow(OffPolicyWorkflowTemplate):
    """Off-policy training workflow for the DDPG agent.

    Each ``step`` collects a rollout, adds it to the replay buffer, and then
    runs ``config.num_updates_per_iter`` rounds of critic update, actor
    update, and soft target-network update.
    """

    @classmethod
    def name(cls):
        """Return the name of this workflow."""
        return "DDPG"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct the workflow and its components from ``config``.

        Builds the training env, the agent, the optimizer, the replay buffer
        and the evaluator, then passes them to the class constructor.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            # ``ori_obs`` is read by the agent's critic loss as the next observation.
            record_ori_obs=True,
        )

        agent = make_mlp_ddpg_agent(
            action_space=env.action_space,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            normalize_obs=config.normalize_obs,
            policy_obs_key=config.agent_network.policy_obs_key,
            value_obs_key=config.agent_network.value_obs_key,
        )

        # one optimizer, two opt_states (in setup function) for both actor and critic
        grad_clip_norm = config.optimizer.grad_clip_norm
        if grad_clip_norm is not None and grad_clip_norm > 0:
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            # Never allow sampling before at least one full batch is stored.
            min_sample_timesteps=max(
                config.batch_size, config.learning_start_timesteps
            ),
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        # Evaluation uses the noise-free action function.
        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(
            env,
            agent,
            optimizer,
            evaluator,
            replay_buffer,
            config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent state and one optimizer state each for actor and critic.

        Returns:
            A tuple ``(agent_state, opt_state)`` where ``opt_state`` has the
            keys ``actor`` and ``critic``.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: rollout, buffer insertion, and updates.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, new_state)`` with the training metrics
            of this iteration and the updated workflow state.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Keep the done flags before cleaning; they are used below to count
        # episodes.
        trajectory_dones = trajectory.dones
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            # Update the observation statistics with the newly collected data.
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def critic_loss_fn(agent_state, sample_batch, key):
            """Return ``(critic loss, loss dict)`` for the gradient update."""
            loss_dict = self.agent.critic_loss(agent_state, sample_batch, key)
            return loss_dict.critic_loss, loss_dict

        def actor_loss_fn(agent_state, sample_batch, key):
            """Return ``(actor loss, loss dict)`` for the gradient update."""
            loss_dict = self.agent.actor_loss(agent_state, sample_batch, key)
            return loss_dict.actor_loss, loss_dict

        # ``detach_fn`` selects the parameters handed to the optimizer and
        # ``attach_fn`` writes the updated ones back into the agent state.
        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, critic_params: agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, actor_params: agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )

        def _sample_and_update_fn(carry, unused_t):
            """Sample one batch, update critic then actor, then the targets."""
            key, agent_state, opt_state = carry

            key, rb_key, critic_key, actor_key = jax.random.split(key, 4)

            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(
                    opt_state.critic, agent_state, sample_batch, critic_key
                )
            )

            # The actor update sees the critic parameters updated just above.
            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(opt_state.actor, agent_state, sample_batch, actor_key)
            )

            target_actor_params = soft_target_update(
                agent_state.params.target_actor_params,
                agent_state.params.actor_params,
                self.config.tau,
            )
            target_critic_params = soft_target_update(
                agent_state.params.target_critic_params,
                agent_state.params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    target_actor_params=target_actor_params,
                    target_critic_params=target_critic_params,
                )
            )

            opt_state = PyTreeDict(actor=actor_opt_state, critic=critic_opt_state)

            return (
                (key, agent_state, opt_state),
                (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
            )

        # NOTE(review): ``scan_and_mean`` presumably averages the per-update
        # outputs over the scan length - confirm against its implementation.
        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )

        train_metrics = DDPGTrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )