Source code for evorl.algorithms.dqn

  1import logging
  2import math
  3from typing import Any
  4
  5import chex
  6import distrax
  7
  8import flax.linen as nn
  9import jax
 10import jax.numpy as jnp
 11import jax.tree_util as jtu
 12import optax
 13from omegaconf import DictConfig
 14
 15from evorl.replay_buffers import ReplayBuffer
 16from evorl.distributed import psum, pmean
 17from evorl.distributed.gradients import agent_gradient_update
 18from evorl.envs import AutoresetMode, Discrete, create_env, Space
 19from evorl.evaluators import Evaluator
 20from evorl.metrics import MetricBase, WorkflowMetric, metric_field
 21from evorl.networks import make_discrete_q_network
 22from evorl.rollout import rollout
 23from evorl.sample_batch import SampleBatch
 24from evorl.types import (
 25    Action,
 26    LossDict,
 27    Params,
 28    PolicyExtraInfo,
 29    PyTreeData,
 30    PyTreeDict,
 31    State,
 32    pytree_field,
 33)
 34from evorl.utils import running_statistics
 35from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, tree_get
 36from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
 37
 38from evorl.agent import Agent, AgentState
 39from .offpolicy_utils import OffPolicyWorkflowTemplate, clean_trajectory
 40
 41logger = logging.getLogger(__name__)
 42
 43
class DQNNetworkParams(PyTreeData):
    """Container for the parameters carried in the DQN agent state."""

    # Parameters of the online Q-network (the ones being optimized).
    q_params: Params
    # Parameters of the target Q-network used to build the TD target.
    target_q_params: Params
    # Epsilon passed to the epsilon-greedy action distribution.
    exploration_epsilon: float
48 49
class DQNTrainMetric(MetricBase):
    """Training metrics returned by each call to ``DQNWorkflow.step``."""

    # No reduce_fn is required here: the value is already reduced in step().
    loss: chex.Array = jnp.zeros(())
    # Individual loss terms; declared with ``pmean`` as the reduce function.
    raw_loss_dict: LossDict = metric_field(
        default_factory=PyTreeDict,
        reduce_fn=pmean,
    )
54 55
class DQNWorkflowMetric(WorkflowMetric):
    """Workflow counters for DQN, adding an update counter to WorkflowMetric."""

    # Count of gradient updates performed so far; it does not need syncing.
    training_updates: chex.Array = jnp.zeros((), dtype=jnp.uint32)
58 59
class DQNAgent(Agent):
    """Q-learning agent for discrete actions, supporting DQN and Double DQN.

    Fields:
        q_network: Module whose ``apply(params, obs)`` yields per-action
            Q-values (last axis indexes actions).
        obs_preprocessor: Optional callable ``(obs, preprocessor_state)``
            returning the processed observation. ``None`` disables
            observation normalization.
        discount: Discount factor applied to the bootstrapped next-state value.
        target_type: ``"DDQN"`` (Double DQN target) or ``"DQN"``.
    """

    q_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)
    discount: float = 0.99
    target_type: str = "DDQN"

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor has been configured."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; a sample from it is used as the
                dummy input for network initialization.
            action_space: Action space (not read by this method).
            key: PRNG key.

        Returns:
            An ``AgentState`` holding freshly initialized online/target
            Q-network parameters and, when normalization is enabled, the
            initial running-statistics state.
        """
        # Add a leading batch axis of size 1 to the sampled observation.
        # NOTE(review): ``key`` is used both for sampling the dummy obs and
        # for network init; kept as-is so existing seeds reproduce.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))

        q_params = self.q_network.init(key, dummy_obs)
        # The target network starts as an exact copy of the online network.
        target_q_params = q_params

        params_states = DQNNetworkParams(
            q_params=q_params,
            target_q_params=target_q_params,
            exploration_epsilon=jnp.zeros(()),  # handle at workflow
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        return AgentState(
            params=params_states, obs_preprocessor_state=obs_preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample epsilon-greedy actions for the observations in ``sample_batch``.

        Args:
            agent_state: Current agent state (parameters, preprocessor state).
            sample_batch: Batch providing ``obs``.
            key: PRNG key used for sampling.

        Returns:
            A tuple ``(actions, extra_info)`` where ``extra_info`` is an empty
            ``PyTreeDict``.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        qs = self.q_network.apply(agent_state.params.q_params, obs)
        # TODO: use tfp.Distribution
        actions_dist = distrax.EpsilonGreedy(
            qs, epsilon=agent_state.params.exploration_epsilon
        )
        # [B]: int from 0~(n-1)
        actions = actions_dist.sample(seed=key)

        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the mode of the epsilon-greedy distribution for evaluation.

        Args:
            agent_state: Current agent state (parameters, preprocessor state).
            sample_batch: Batch providing ``obs``.
            key: PRNG key (not used; no sampling happens here).

        Returns:
            A tuple ``(actions, extra_info)`` where ``extra_info`` is an empty
            ``PyTreeDict``.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        # Feed the preprocessed ``obs`` (not the raw ``sample_batch.obs``) so
        # the network sees the same input distribution as in
        # ``compute_actions`` and ``loss`` when normalization is enabled.
        qs = self.q_network.apply(agent_state.params.q_params, obs)

        actions_dist = distrax.EpsilonGreedy(
            qs, epsilon=agent_state.params.exploration_epsilon
        )
        actions = actions_dist.mode()

        return actions, PyTreeDict()

    def loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the one-step TD loss for the Q-network.

        Args:
            agent_state: Current agent state (online and target parameters).
            sample_batch: Batch providing ``obs``, ``actions``, ``rewards`` and
                ``extras.env_extras`` with ``ori_obs`` and ``termination``.
            key: PRNG key (not used by this loss).

        Returns:
            A ``PyTreeDict`` with ``q_loss`` (mean squared TD error) and
            ``q_value`` (mean Q-value of the taken actions).

        Raises:
            ValueError: If ``target_type`` is neither ``"DDQN"`` nor ``"DQN"``.
        """
        obs = sample_batch.obs
        actions = sample_batch.actions
        rewards = sample_batch.rewards
        # ``ori_obs`` is used as the next observation.
        # NOTE(review): presumably the pre-autoreset observation recorded by
        # the env -- confirm against the env wrapper.
        next_obs = sample_batch.extras.env_extras.ori_obs

        if self.normalize_obs:
            next_obs = self.obs_preprocessor(
                next_obs, agent_state.obs_preprocessor_state
            )
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        # Zero the bootstrap term on terminated transitions.
        discounts = self.discount * (1 - sample_batch.extras.env_extras.termination)

        qs = self.q_network.apply(agent_state.params.q_params, obs)
        # [B,n]->[B]
        qs = jnp.take_along_axis(qs, actions[..., None], axis=-1).squeeze(-1)

        # DQN_target: [B,n]
        next_qs = self.q_network.apply(agent_state.params.target_q_params, next_obs)

        if self.target_type == "DDQN":
            # Double DQN: the online network selects the action, the target
            # network evaluates it.
            next_actions = self.q_network.apply(
                agent_state.params.q_params, next_obs
            ).argmax(axis=-1, keepdims=True)  # [B,1]
            next_qs = jnp.take_along_axis(next_qs, next_actions, axis=-1).squeeze(-1)
        elif self.target_type == "DQN":
            next_qs = next_qs.max(axis=-1)  # [B,n]->[B]
        else:
            raise ValueError(f"Unknown target_type: {self.target_type}")

        # The target is treated as a constant w.r.t. the online parameters.
        qs_target = jax.lax.stop_gradient(rewards + discounts * next_qs)

        q_loss = optax.squared_error(qs, qs_target).mean()

        return PyTreeDict(q_loss=q_loss, q_value=qs.mean())
167 168
def make_mlp_discrete_dqn_agent(
    action_space: Space,
    discount: float = 0.99,
    target_type: str = "DDQN",
    q_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    normalize_obs: bool = False,
    value_obs_key: str = "",
) -> DQNAgent:
    """Build a ``DQNAgent`` backed by a discrete Q-network.

    Args:
        action_space: Action space of the environment; must be ``Discrete``.
        discount: Discount factor forwarded to the agent.
        target_type: TD-target variant forwarded to the agent
            (``"DDQN"`` or ``"DQN"``).
        q_hidden_layer_sizes: Hidden layer widths of the Q-network, one entry
            per layer.
        normalize_obs: When true, observations are normalized with
            ``running_statistics.normalize``.
        value_obs_key: Forwarded to the network builder as ``obs_key``.

    Returns:
        The configured ``DQNAgent``.

    Raises:
        AssertionError: If ``action_space`` is not ``Discrete``.
    """
    assert isinstance(action_space, Discrete), (
        "Only Discrete action space is supported."
    )

    q_network = make_discrete_q_network(
        action_size=action_space.n,
        hidden_layer_sizes=q_hidden_layer_sizes,
        obs_key=value_obs_key,
    )

    # ``None`` signals to the agent that observation normalization is off.
    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return DQNAgent(
        q_network=q_network,
        obs_preprocessor=obs_preprocessor,
        discount=discount,
        target_type=target_type,
    )
199 200
class DQNWorkflow(OffPolicyWorkflowTemplate):
    """Off-policy training workflow that drives a ``DQNAgent``."""

    @classmethod
    def name(cls):
        """Return the workflow identifier."""
        return "DQN"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Assemble env, agent, optimizer, buffer and evaluator from ``config``.

        Also attaches ``epsilon_scheduler`` (a linear schedule over training
        updates) to the returned workflow.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        assert isinstance(env.action_space, Discrete), (
            "Only Discrete action space is supported."
        )

        agent = make_mlp_discrete_dqn_agent(
            action_space=env.action_space,
            discount=config.discount,
            target_type=config.target_type,
            q_hidden_layer_sizes=config.agent_network.q_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
            value_obs_key=config.agent_network.value_obs_key,
        )

        # Gradient clipping is applied only for a positive, non-None norm.
        clip_norm = config.optimizer.grad_clip_norm
        use_clipping = clip_norm is not None and clip_norm > 0
        if use_clipping:
            optimizer = optax.chain(
                optax.clip_by_global_norm(clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        # Sampling needs at least one full batch in the buffer.
        min_timesteps = max(config.batch_size, config.learning_start_timesteps)
        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=min_timesteps,
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        workflow = cls(env, agent, optimizer, evaluator, replay_buffer, config)

        # Round the iteration count up to a whole number of folds.
        num_folds = math.ceil(
            config.total_timesteps
            / (config.num_envs * config.rollout_length * config.fold_iters)
        )
        num_iterations = num_folds * config.fold_iters
        total_training_updates = num_iterations * config.num_updates_per_iter

        eps_cfg = config.exploration_epsilon
        decay_steps = (eps_cfg.exploration_fraction * total_training_updates) - 1
        workflow.epsilon_scheduler = optax.linear_schedule(
            init_value=eps_cfg.start,
            end_value=eps_cfg.end,
            transition_steps=decay_steps,
        )

        return workflow

    def _setup_workflow_metrics(self) -> MetricBase:
        """Return a fresh ``DQNWorkflowMetric``."""
        return DQNWorkflowMetric()

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize agent and optimizer state; set epsilon to schedule step 0."""
        initial_state = self.agent.init(
            self.env.obs_space, self.env.action_space, key
        )
        # Only the online Q-network parameters are optimized.
        opt_state = self.optimizer.init(initial_state.params.q_params)

        params_with_eps = initial_state.params.replace(
            exploration_epsilon=self.epsilon_scheduler(0)
        )
        initial_state = initial_state.replace(params=params_with_eps)

        return initial_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one iteration: collect a rollout, store it, then update the agent.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, new_state)``.
        """
        # The 4-way split is kept although the last key is not consumed:
        # changing the split count would change the derived keys.
        key, rollout_key, learn_key, _buffer_key = jax.random.split(
            state.key, num=4
        )

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Keep the done flags before the trajectory is cleaned; they are
        # needed for the episode counter below.
        dones = trajectory.dones
        trajectory = tree_stop_gradient(
            flatten_rollout_trajectory(clean_trajectory(trajectory))
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            new_preprocessor_state = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                dp_axis_name=self.dp_axis_name,
            )
            agent_state = agent_state.replace(
                obs_preprocessor_state=new_preprocessor_state
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def loss_fn(agent_state, sample_batch, key):
            # The scalar q_loss is differentiated; the full dict is the aux output.
            losses = self.agent.loss(agent_state, sample_batch, key)
            return losses.q_loss, losses

        def _attach_q_params(agent_state, q_params):
            return agent_state.replace(
                params=agent_state.params.replace(q_params=q_params)
            )

        def _detach_q_params(agent_state):
            return agent_state.params.q_params

        q_update_fn = agent_gradient_update(
            loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=_attach_q_params,
            detach_fn=_detach_q_params,
        )

        def _soft_update_q(agent_state):
            # Move the target parameters towards the online ones by ``tau``.
            new_target = soft_target_update(
                agent_state.params.target_q_params,
                agent_state.params.q_params,
                self.config.tau,
            )
            return agent_state.replace(
                params=agent_state.params.replace(target_q_params=new_target)
            )

        def _keep_target(agent_state):
            return agent_state

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state, wf_metrics = carry

            key, rb_key, q_key = jax.random.split(key, 3)

            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            (q_loss, loss_dict), agent_state, opt_state = q_update_fn(
                opt_state, agent_state, sample_batch, q_key
            )

            wf_metrics = wf_metrics.replace(
                training_updates=wf_metrics.training_updates + 1
            )

            # Refresh the target network every ``target_network_update_freq``
            # updates.
            should_update_target = (
                wf_metrics.training_updates % self.config.target_network_update_freq
                == 0
            )
            agent_state = jax.lax.cond(
                should_update_target,
                _soft_update_q,
                _keep_target,
                agent_state,
            )

            # The exploration epsilon follows the schedule over update count.
            next_epsilon = self.epsilon_scheduler(wf_metrics.training_updates)
            agent_state = agent_state.replace(
                params=agent_state.params.replace(exploration_epsilon=next_epsilon)
            )

            return (key, agent_state, opt_state, wf_metrics), (q_loss, loss_dict)

        init_carry = (learn_key, agent_state, state.opt_state, state.metrics)
        (_, agent_state, opt_state, workflow_metrics), (q_loss, loss_dict) = (
            scan_and_mean(
                _sample_and_update_fn,
                init_carry,
                (),
                length=self.config.num_updates_per_iter,
            )
        )

        train_metrics = DQNTrainMetric(
            loss=q_loss,
            raw_loss_dict=loss_dict,
        )

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        workflow_metrics = workflow_metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )