1import logging
2import math
3from typing import Any
4
5import chex
6import distrax
7
8import flax.linen as nn
9import jax
10import jax.numpy as jnp
11import jax.tree_util as jtu
12import optax
13from omegaconf import DictConfig
14
15from evorl.replay_buffers import ReplayBuffer
16from evorl.distributed import psum, pmean
17from evorl.distributed.gradients import agent_gradient_update
18from evorl.envs import AutoresetMode, Discrete, create_env, Space
19from evorl.evaluators import Evaluator
20from evorl.metrics import MetricBase, WorkflowMetric, metric_field
21from evorl.networks import make_discrete_q_network
22from evorl.rollout import rollout
23from evorl.sample_batch import SampleBatch
24from evorl.types import (
25 Action,
26 LossDict,
27 Params,
28 PolicyExtraInfo,
29 PyTreeData,
30 PyTreeDict,
31 State,
32 pytree_field,
33)
34from evorl.utils import running_statistics
35from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, tree_get
36from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
37
38from evorl.agent import Agent, AgentState
39from .offpolicy_utils import OffPolicyWorkflowTemplate, clean_trajectory
40
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
42
43
class DQNNetworkParams(PyTreeData):
    """Parameter container carried in ``AgentState.params`` for ``DQNAgent``.

    Holds the online Q-network parameters, the target Q-network parameters used
    to build TD targets, and the current epsilon for epsilon-greedy rollouts.
    """

    # Online Q-network parameters (the ones the workflow's optimizer updates).
    q_params: Params
    # Target Q-network parameters; start as a copy of q_params and are moved
    # toward them with soft_target_update by the workflow.
    target_q_params: Params
    # Exploration epsilon. Annotated as float, but in practice it holds a 0-d
    # array: DQNAgent.init stores jnp.zeros(()) and the workflow overwrites it
    # with epsilon-scheduler outputs.
    exploration_epsilon: float
48
49
class DQNTrainMetric(MetricBase):
    """Training metrics returned by one ``DQNWorkflow.step`` call."""

    # Q loss averaged over the iteration's updates. No reduce_fn: per the
    # original note it is already reduced in step() (presumably across the
    # data-parallel axis by the gradient-update helper; confirm).
    loss: chex.Array = jnp.zeros(())
    # Entries of the agent's loss dict (q_loss, q_value) averaged over the
    # iteration's updates, then reduced with pmean.
    raw_loss_dict: LossDict = metric_field(default_factory=PyTreeDict, reduce_fn=pmean)
54
55
class DQNWorkflowMetric(WorkflowMetric):
    """Workflow metrics extended with a counter of gradient updates."""

    # Number of Q-network gradient updates performed so far. Incremented once
    # per update in DQNWorkflow.step; drives the epsilon schedule and the
    # target-network update period. No cross-device sync needed.
    training_updates: chex.Array = jnp.zeros((), dtype=jnp.uint32)
58
59
class DQNAgent(Agent):
    """Q-learning agent over a discrete action space (DQN or Double-DQN targets).

    Fields:
        q_network: Flax module mapping observations to per-action Q-values.
        obs_preprocessor: optional ``fn(obs, obs_preprocessor_state)`` applied
            to observations before the Q-network; ``None`` disables it.
        discount: reward discount factor used in the TD target.
        target_type: ``"DDQN"`` (online net selects the next action, target net
            evaluates it) or ``"DQN"`` (max over target-net Q-values).
    """

    q_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)
    discount: float = 0.99
    target_type: str = "DDQN"

    @property
    def normalize_obs(self):
        """Whether observations are passed through ``obs_preprocessor``."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        The target network starts as an exact copy of the online network, and
        the exploration epsilon is a zero placeholder that the workflow sets.
        Running observation statistics are created only if ``normalize_obs``.
        """
        # Sample one observation and add a leading batch axis of size 1 so the
        # network can be traced for parameter initialization.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))

        q_params = self.q_network.init(key, dummy_obs)
        target_q_params = q_params

        params_states = DQNNetworkParams(
            q_params=q_params,
            target_q_params=target_q_params,
            exploration_epsilon=jnp.zeros(()),  # handle at workflow
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        return AgentState(
            params=params_states, obs_preprocessor_state=obs_preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample epsilon-greedy actions for environment rollouts.

        ``agent_state.params.exploration_epsilon`` is the weight of the uniform
        random component; ``key`` seeds the sampling.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        qs = self.q_network.apply(agent_state.params.q_params, obs)
        # TODO: use tfp.Distribution
        actions_dist = distrax.EpsilonGreedy(
            qs, epsilon=agent_state.params.exploration_epsilon
        )
        # [B]: int from 0~(n-1)
        actions = actions_dist.sample(seed=key)

        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return greedy (argmax-Q) actions for evaluation; ``key`` is unused."""
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        # Use the (possibly normalized) obs: when normalize_obs is enabled the
        # Q-network was trained on normalized inputs, not raw sample_batch.obs.
        qs = self.q_network.apply(agent_state.params.q_params, obs)

        # Take the argmax of Q directly rather than EpsilonGreedy(...).mode():
        # with epsilon == 1 (e.g. at the start of a decay schedule) the
        # EpsilonGreedy probabilities are uniform, so its mode is always action
        # 0. For epsilon < 1 both give the first maximizing action.
        actions = jnp.argmax(qs, axis=-1)

        return actions, PyTreeDict()

    def loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Mean squared TD error of the online Q-network on a replay batch.

        Returns:
            ``PyTreeDict`` with ``q_loss`` (scalar loss to minimize) and
            ``q_value`` (mean Q-value of the taken actions, for logging).

        Raises:
            ValueError: if ``target_type`` is neither ``"DDQN"`` nor ``"DQN"``.
        """
        obs = sample_batch.obs
        actions = sample_batch.actions
        rewards = sample_batch.rewards
        # Next observation as recorded by the env wrapper (the workflow creates
        # the env with record_ori_obs=True); presumably the pre-autoreset
        # observation — confirm against the env implementation.
        next_obs = sample_batch.extras.env_extras.ori_obs

        if self.normalize_obs:
            next_obs = self.obs_preprocessor(
                next_obs, agent_state.obs_preprocessor_state
            )
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        # Terminal transitions get a zero bootstrap weight.
        discounts = self.discount * (1 - sample_batch.extras.env_extras.termination)

        qs = self.q_network.apply(agent_state.params.q_params, obs)
        # [B,n]->[B]
        qs = jnp.take_along_axis(qs, actions[..., None], axis=-1).squeeze(-1)

        # DQN_target: [B,n]
        next_qs = self.q_network.apply(agent_state.params.target_q_params, next_obs)

        if self.target_type == "DDQN":
            # Double DQN: the online network picks the next action, the target
            # network scores it.
            next_actions = self.q_network.apply(
                agent_state.params.q_params, next_obs
            ).argmax(axis=-1, keepdims=True)  # [B,1]
            next_qs = jnp.take_along_axis(next_qs, next_actions, axis=-1).squeeze(-1)
        elif self.target_type == "DQN":
            next_qs = next_qs.max(axis=-1)  # [B,n]->[B]
        else:
            raise ValueError(f"Unknown target_type: {self.target_type}")

        qs_target = jax.lax.stop_gradient(rewards + discounts * next_qs)

        q_loss = optax.squared_error(qs, qs_target).mean()

        return PyTreeDict(q_loss=q_loss, q_value=qs.mean())
167
168
def make_mlp_discrete_dqn_agent(
    action_space: Space,
    discount: float = 0.99,
    target_type: str = "DDQN",
    q_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    normalize_obs: bool = False,
    value_obs_key: str = "",
):
    """Build a ``DQNAgent`` whose Q-network is an MLP.

    Args:
        action_space: environment action space; must be ``Discrete``.
        discount: reward discount factor for TD targets.
        target_type: ``"DDQN"`` or ``"DQN"``; selects how next-state values
            are computed in the loss.
        q_hidden_layer_sizes: hidden-layer widths of the Q-network.
        normalize_obs: if True, observations are normalized with
            ``running_statistics.normalize``.
        value_obs_key: forwarded to ``make_discrete_q_network`` as ``obs_key``.

    Returns:
        The configured ``DQNAgent``.

    Raises:
        AssertionError: if ``action_space`` is not ``Discrete``.
        ValueError: if ``target_type`` is not ``"DDQN"`` or ``"DQN"``.
    """
    assert isinstance(action_space, Discrete), (
        "Only Discrete action space is supported."
    )
    # Fail at construction time instead of on the first traced loss call.
    if target_type not in ("DDQN", "DQN"):
        raise ValueError(f"Unknown target_type: {target_type}")

    q_network = make_discrete_q_network(
        action_size=action_space.n,
        hidden_layer_sizes=q_hidden_layer_sizes,
        obs_key=value_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return DQNAgent(
        q_network=q_network,
        obs_preprocessor=obs_preprocessor,
        discount=discount,
        target_type=target_type,
    )
199
200
class DQNWorkflow(OffPolicyWorkflowTemplate):
    """Off-policy training workflow for ``DQNAgent``.

    Each ``step`` does three things:

    1. Collects a rollout with epsilon-greedy actions.
    2. Appends it to the replay buffer.
    3. Performs ``config.num_updates_per_iter`` gradient updates of the online
       Q-network on batches sampled from the buffer, with periodic soft
       target-network updates and a per-update epsilon schedule.
    """

    @classmethod
    def name(cls):
        """Return the workflow's display name."""
        return "DQN"

    @staticmethod
    def _make_epsilon_scheduler(config: DictConfig, total_training_updates):
        """Build the exploration-epsilon schedule, indexed by training updates.

        Epsilon moves linearly from ``exploration_epsilon.start`` to
        ``exploration_epsilon.end`` over the first ``exploration_fraction`` of
        ``total_training_updates``.
        """
        eps_cfg = config.exploration_epsilon
        transition_steps = eps_cfg.exploration_fraction * total_training_updates - 1
        if transition_steps <= 0:
            # optax.linear_schedule returns a constant ``init_value`` schedule
            # for non-positive transition_steps, which would freeze epsilon at
            # ``start`` for the whole run. With no decay window, use ``end``.
            return optax.constant_schedule(eps_cfg.end)
        return optax.linear_schedule(
            init_value=eps_cfg.start,
            end_value=eps_cfg.end,
            transition_steps=transition_steps,
        )

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the env, agent, optimizer, replay buffer, evaluator and workflow."""
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        assert isinstance(env.action_space, Discrete), (
            "Only Discrete action space is supported."
        )

        agent = make_mlp_discrete_dqn_agent(
            action_space=env.action_space,
            discount=config.discount,
            target_type=config.target_type,
            q_hidden_layer_sizes=config.agent_network.q_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
            value_obs_key=config.agent_network.value_obs_key,
        )

        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=max(
                config.batch_size, config.learning_start_timesteps
            ),
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        workflow = cls(env, agent, optimizer, evaluator, replay_buffer, config)

        # Iterations are rounded up to a whole number of fold_iters chunks.
        num_iterations = (
            math.ceil(
                config.total_timesteps
                / (config.num_envs * config.rollout_length * config.fold_iters)
            )
            * config.fold_iters
        )
        total_training_updates = num_iterations * config.num_updates_per_iter
        workflow.epsilon_scheduler = cls._make_epsilon_scheduler(
            config, total_training_updates
        )

        return workflow

    def _setup_workflow_metrics(self) -> MetricBase:
        return DQNWorkflowMetric()

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent and an optimizer state over ``q_params`` only.

        The exploration epsilon is set to the schedule's value at update 0.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        # Only the online Q-network is optimized; target params follow it via
        # soft_target_update in step().
        opt_state = self.optimizer.init(agent_state.params.q_params)

        agent_state = agent_state.replace(
            params=agent_state.params.replace(
                exploration_epsilon=self.epsilon_scheduler(0)
            )
        )

        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration.

        Steps: rollout → optional obs-statistics update → replay-buffer
        insertion → ``config.num_updates_per_iter`` Q-network updates.

        Returns:
            The update-averaged ``DQNTrainMetric`` and the new workflow state.
        """
        # The 4th key is unused but the 4-way split is kept so the RNG stream
        # (and hence reproducibility of existing runs) is unchanged.
        key, rollout_key, learn_key, _ = jax.random.split(state.key, num=4)

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Keep the done flags before cleaning/flattening; they are used below
        # to count finished episodes.
        trajectory_dones = trajectory.dones
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.loss(agent_state, sample_batch, key)
            return loss_dict.q_loss, loss_dict

        q_update_fn = agent_gradient_update(
            loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, q_params: agent_state.replace(
                params=agent_state.params.replace(q_params=q_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.q_params,
        )

        def _soft_update_q(agent_state):
            # Move the target params toward the online params at rate config.tau.
            target_q_params = soft_target_update(
                agent_state.params.target_q_params,
                agent_state.params.q_params,
                self.config.tau,
            )
            return agent_state.replace(
                params=agent_state.params.replace(target_q_params=target_q_params)
            )

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state, wf_metrics = carry

            key, rb_key, q_key = jax.random.split(key, 3)

            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            (q_loss, loss_dict), agent_state, opt_state = q_update_fn(
                opt_state, agent_state, sample_batch, q_key
            )

            wf_metrics = wf_metrics.replace(
                training_updates=wf_metrics.training_updates + 1
            )

            # Refresh the target network every target_network_update_freq updates.
            agent_state = jax.lax.cond(
                wf_metrics.training_updates % self.config.target_network_update_freq
                == 0,
                _soft_update_q,
                lambda agent_state: agent_state,
                agent_state,
            )

            # Advance epsilon according to the number of updates done so far.
            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    exploration_epsilon=self.epsilon_scheduler(
                        wf_metrics.training_updates
                    )
                )
            )

            return (key, agent_state, opt_state, wf_metrics), (q_loss, loss_dict)

        (_, agent_state, opt_state, workflow_metrics), (q_loss, loss_dict) = (
            scan_and_mean(
                _sample_and_update_fn,
                (learn_key, agent_state, state.opt_state, state.metrics),
                (),
                length=self.config.num_updates_per_iter,
            )
        )

        train_metrics = DQNTrainMetric(
            loss=q_loss,
            raw_loss_dict=loss_dict,
        )

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        workflow_metrics = workflow_metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )
428 )