Source code for evorl.algorithms.ddpg

  1import logging
  2from typing import Any
  3
  4import chex
  5import flax.linen as nn
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9import optax
 10from omegaconf import DictConfig
 11
 12from evorl.distributed import psum, pmean
 13from evorl.distributed.gradients import agent_gradient_update
 14from evorl.envs import AutoresetMode, Box, create_env, Space
 15from evorl.evaluators import Evaluator
 16from evorl.metrics import MetricBase, metric_field
 17from evorl.networks import make_policy_network, make_q_network
 18from evorl.rollout import rollout
 19from evorl.replay_buffers import ReplayBuffer
 20from evorl.sample_batch import SampleBatch
 21from evorl.types import (
 22    Action,
 23    LossDict,
 24    Params,
 25    PolicyExtraInfo,
 26    PyTreeData,
 27    PyTreeDict,
 28    State,
 29    pytree_field,
 30)
 31from evorl.utils import running_statistics
 32from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient, tree_get
 33from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
 34
 35from evorl.agent import Agent, AgentState
 36from .offpolicy_utils import OffPolicyWorkflowTemplate, clean_trajectory
 37
 38logger = logging.getLogger(__name__)
 39
 40
class DDPGTrainMetric(MetricBase):
    """Training metrics produced by one DDPG workflow step."""

    # Actor loss value reported by the update loop of the workflow step.
    actor_loss: chex.Array
    # Critic loss value reported by the update loop of the workflow step.
    critic_loss: chex.Array
    # Loss dictionaries of the critic and the actor merged into one mapping;
    # this field declares `pmean` as its reduce function and defaults to an
    # empty PyTreeDict.
    raw_loss_dict: LossDict = metric_field(
        default_factory=PyTreeDict,
        reduce_fn=pmean,
    )
45 46
class DDPGNetworkParams(PyTreeData):
    """Parameter container for the DDPG learner (online and target networks)."""

    # Online networks: the ones differentiated by the actor/critic losses.
    actor_params: Params
    critic_params: Params

    # Target networks: used to build the critic's bootstrapped target.
    target_actor_params: Params
    target_critic_params: Params
55 56
class DDPGAgent(Agent):
    """The Agent for DDPG.

    Bundles a deterministic actor network and a Q-value critic network and
    implements action selection as well as the critic and actor losses.

    Attributes:
        critic_network: Module applied as ``apply(params, obs, actions)``.
        actor_network: Module applied as ``apply(params, obs)``.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``
            applied to observations before every network call. ``None``
            disables observation normalization.
        discount: Discount factor used in the critic's bootstrapped target.
        exploration_epsilon: Scale of the Gaussian noise added to the actions
            in ``compute_actions``.
    """

    critic_network: nn.Module
    actor_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    discount: float = 1
    exploration_epsilon: float = 0.5

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs: Any) -> Any:
        """Apply the observation preprocessor to `obs` if one is configured."""
        if self.normalize_obs:
            return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)
        return obs

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Space used to sample a dummy observation.
            action_space: Space used to sample a dummy action.
            key: PRNG key for parameter initialization and dummy sampling.

        Returns:
            An ``AgentState`` holding the network parameters (target networks
            start as copies of the online networks) and, when observation
            normalization is enabled, the initial preprocessor state.
        """
        key, q_key, actor_key = jax.random.split(key, num=3)
        # Draw the two dummy samples from distinct keys rather than reusing
        # the same key twice; `q_key` and `actor_key` are unaffected.
        obs_key, action_key = jax.random.split(key)

        # Prepend a batch axis of size 1 to the sampled dummy inputs.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(obs_key))
        dummy_action = action_space.sample(action_key)[None, ...]

        critic_params = self.critic_network.init(q_key, dummy_obs, dummy_action)
        actor_params = self.actor_network.init(actor_key, dummy_obs)

        params_state = DDPGNetworkParams(
            critic_params=critic_params,
            actor_params=actor_params,
            target_critic_params=critic_params,
            target_actor_params=actor_params,
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        return AgentState(
            params=params_state, obs_preprocessor_state=obs_preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute exploratory actions for ``sample_batch.obs``.

        The actor output is perturbed with Gaussian noise scaled by
        ``exploration_epsilon`` and then clipped to ``[-1, 1]``.

        Returns:
            The noisy, clipped actions and an empty ``PyTreeDict``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        actions = self.actor_network.apply(agent_state.params.actor_params, obs)
        # add random noise
        noise = jax.random.normal(key, actions.shape) * self.exploration_epsilon
        actions = jnp.clip(actions + noise, -1.0, 1.0)

        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Compute the actor's actions for ``sample_batch.obs`` without noise.

        ``key`` is accepted for interface compatibility and is not used.

        Returns:
            The actor outputs and an empty ``PyTreeDict``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions = self.actor_network.apply(agent_state.params.actor_params, obs)

        return actions, PyTreeDict()

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the critic's mean squared TD error.

        The target is ``rewards + discount * (1 - termination) * Q'(s', mu'(s'))``
        where ``Q'`` and ``mu'`` are the target critic and target actor. No
        gradient flows through the target. ``key`` is not used.

        Returns:
            A ``PyTreeDict`` with ``critic_loss`` and the mean ``q_value``.
        """
        env_extras = sample_batch.extras.env_extras
        # The next observation is read from `ori_obs`.
        # NOTE(review): presumably the pre-autoreset observation recorded by
        # the env (`record_ori_obs=True`) -- confirm against the env wrapper.
        next_obs = self._preprocess_obs(agent_state, env_extras.ori_obs)
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions = sample_batch.actions

        next_actions = self.actor_network.apply(
            agent_state.params.target_actor_params, next_obs
        )
        next_qs = self.critic_network.apply(
            agent_state.params.target_critic_params, next_obs, next_actions
        )

        # Where `termination` is 1 the bootstrap term is zeroed out.
        discounts = self.discount * (1 - env_extras.termination)

        qs_target = sample_batch.rewards + discounts * next_qs
        qs_target = jax.lax.stop_gradient(qs_target)

        qs = self.critic_network.apply(agent_state.params.critic_params, obs, actions)

        # NOTE: plain squared error is used here; `optax.huber_loss` would be
        # a drop-in alternative.
        q_loss = optax.squared_error(qs, qs_target).mean()

        return PyTreeDict(critic_loss=q_loss, q_value=qs.mean())

    def actor_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the actor loss ``-mean(Q(s, mu(s)))``.

        Both the online actor and the online critic parameters are read from
        ``agent_state``. ``key`` is not used.

        Returns:
            A ``PyTreeDict`` with ``actor_loss``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        # [T*B, A]
        actions = self.actor_network.apply(agent_state.params.actor_params, obs)

        q_values = self.critic_network.apply(
            agent_state.params.critic_params, obs, actions
        )
        return PyTreeDict(actor_loss=-jnp.mean(q_values))
178 179
def make_mlp_ddpg_agent(
    action_space: Space,
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    discount: float = 1,
    exploration_epsilon: float = 0.5,
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
):
    """Build a ``DDPGAgent`` with MLP actor and critic networks.

    Args:
        action_space: Action space of the environment; must be a ``Box``.
            Its first dimension is used as the action size.
        critic_hidden_layer_sizes: Hidden layer widths of the Q network.
        actor_hidden_layer_sizes: Hidden layer widths of the policy network.
        discount: Discount factor forwarded to the agent.
        exploration_epsilon: Exploration noise scale forwarded to the agent.
        normalize_obs: If true, the agent normalizes observations with
            ``running_statistics.normalize``.
        policy_obs_key: Forwarded to the policy network as ``obs_key``.
        value_obs_key: Forwarded to the Q network as ``obs_key``.

    Returns:
        The configured ``DDPGAgent``.

    Raises:
        AssertionError: If ``action_space`` is not a ``Box``.
    """
    assert isinstance(action_space, Box), (
        "Only continuous action space is supported."
    )

    action_size = action_space.shape[0]

    critic_network = make_q_network(
        hidden_layer_sizes=critic_hidden_layer_sizes,
        obs_key=value_obs_key,
    )
    # The final tanh bounds the actor output to [-1, 1].
    actor_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        activation_final=nn.tanh,
        obs_key=policy_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return DDPGAgent(
        critic_network=critic_network,
        actor_network=actor_network,
        obs_preprocessor=obs_preprocessor,
        discount=discount,
        exploration_epsilon=exploration_epsilon,
    )
217 218
class DDPGWorkflow(OffPolicyWorkflowTemplate):
    """Off-policy training workflow for DDPG.

    Each call to ``step`` collects a rollout, adds it to the replay buffer and
    runs ``config.num_updates_per_iter`` critic/actor updates on sampled
    batches, each followed by a soft update of the target networks.
    """

    @classmethod
    def name(cls):
        """Return the name of this workflow."""
        return "DDPG"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct the workflow and all of its components from `config`.

        Builds the training env (autoreset enabled, original observations
        recorded), the MLP DDPG agent, an Adam optimizer with optional global
        norm gradient clipping, the replay buffer, and an evaluator running on
        a separate env with autoreset disabled.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        agent = make_mlp_ddpg_agent(
            action_space=env.action_space,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            normalize_obs=config.normalize_obs,
            policy_obs_key=config.agent_network.policy_obs_key,
            value_obs_key=config.agent_network.value_obs_key,
        )

        # one optimizer, two opt_states (in setup function) for both actor and critic
        optimizer = optax.adam(config.optimizer.lr)
        grad_clip_norm = config.optimizer.grad_clip_norm
        if grad_clip_norm is not None and grad_clip_norm > 0:
            # Clip first, then hand the clipped gradients to Adam.
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optimizer,
            )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            # Never allow sampling before at least one full batch is stored.
            min_sample_timesteps=max(
                config.batch_size, config.learning_start_timesteps
            ),
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(
            env,
            agent,
            optimizer,
            evaluator,
            replay_buffer,
            config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent state and one optimizer state per network.

        Returns:
            The agent state and a ``PyTreeDict`` with the ``actor`` and
            ``critic`` optimizer states (both created by the shared optimizer).
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        # Built the same way as in `step`, where this dict is the scan carry.
        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )
        return agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: rollout, buffer insert, and updates.

        Args:
            state: Current workflow state.

        Returns:
            The training metrics of this iteration and the updated state.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Keep the done flags before the trajectory is cleaned and flattened;
        # they are only needed for the episode counter below.
        trajectory_dones = trajectory.dones
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            # Fold the freshly collected observations into the running stats.
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        replay_buffer_state = self.replay_buffer.add(
            state.replay_buffer_state, trajectory
        )

        def critic_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.critic_loss(agent_state, sample_batch, key)
            return loss_dict.critic_loss, loss_dict

        def actor_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.actor_loss(agent_state, sample_batch, key)
            return loss_dict.actor_loss, loss_dict

        # `detach_fn` selects the parameters to differentiate, `attach_fn`
        # writes the updated parameters back into the agent state.
        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, critic_params: agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, actor_params: agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry

            key, rb_key, critic_key, actor_key = jax.random.split(key, 4)

            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            # Critic first; the actor update then sees the updated critic.
            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(
                    opt_state.critic, agent_state, sample_batch, critic_key
                )
            )

            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(opt_state.actor, agent_state, sample_batch, actor_key)
            )

            # Move the target networks towards the online networks, with the
            # rate given by `config.tau`.
            target_actor_params = soft_target_update(
                agent_state.params.target_actor_params,
                agent_state.params.actor_params,
                self.config.tau,
            )
            target_critic_params = soft_target_update(
                agent_state.params.target_critic_params,
                agent_state.params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    target_actor_params=target_actor_params,
                    target_critic_params=target_critic_params,
                )
            )

            opt_state = PyTreeDict(actor=actor_opt_state, critic=critic_opt_state)

            return (
                (key, agent_state, opt_state),
                (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
            )

        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )

        train_metrics = DDPGTrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            replay_buffer_state=replay_buffer_state,
            opt_state=opt_state,
        )