Source code for evorl.algorithms.erl.erl_td3.erl_eda

  1import logging
  2import math
  3from omegaconf import DictConfig
  4
  5import chex
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9import optax
 10
 11from evorl.replay_buffers import ReplayBuffer
 12from evorl.distributed import agent_gradient_update
 13from evorl.metrics import MetricBase
 14from evorl.types import PyTreeDict, State
 15from evorl.utils.jax_utils import (
 16    tree_stop_gradient,
 17    rng_split_like_tree,
 18    right_shift_with_padding,
 19    scan_and_mean,
 20)
 21from evorl.utils.rl_toolkits import soft_target_update, flatten_rollout_trajectory
 22from evorl.evaluators import Evaluator, EpisodeCollector
 23from evorl.agent import AgentState, Agent
 24from evorl.envs import create_env, AutoresetMode
 25from evorl.recorders import get_1d_array_statistics, add_prefix
 26from evorl.ec.optimizers import ECState, OpenES, ExponentialScheduleSpec
 27from evorl.algorithms.td3 import make_mlp_td3_agent, TD3TrainMetric
 28from evorl.algorithms.offpolicy_utils import clean_trajectory, skip_replay_buffer_state
 29
 30from ..erl_workflow import ERLTrainMetric
 31from .erl_td3_workflow import erl_replace_td3_actor_params, ERLTD3WorkflowTemplate
 32
 33logger = logging.getLogger(__name__)
 34
 35
class EvaluateMetric(MetricBase):
    """Evaluation metrics for the RL actor and the EC population-center actor.

    In ``ERLEDAWorkflow.evaluate`` each field is filled with the mean over the
    evaluation episodes.
    """

    # Mean episode return / length of the RL agent's own actor.
    rl_episode_returns: chex.Array
    rl_episode_lengths: chex.Array
    # Mean episode return / length of the actor whose parameters are the EC
    # population mean (``ec_opt_state.mean``).
    pop_center_episode_returns: chex.Array
    pop_center_episode_lengths: chex.Array
41 42
class ERLEDAWorkflow(ERLTD3WorkflowTemplate):
    """ERL w/ EDA: an evolution strategy population plus one TD3 learner.

    Configs:

    - EC: ``pop_size`` actors sampled around a population center (``mean``).
    - RL: 1 (actor, critic) pair trained with TD3.
    - Shared replay buffer, filled by both EC and RL rollouts.

    Every ``rl_injection_interval`` iterations the RL actor is blended into the
    population center (see ``_rl_injection``). ``_build_from_config`` creates an
    ``OpenES`` optimizer. The rest of the workflow only uses the EC optimizer's
    ``init``/``ask``/``tell`` and the ``mean`` field of its state.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # override: use the single-(actor, critic) TD3 update from this module.
        self._rl_update_fn = build_rl_update_fn(self.agent, self.optimizer, self.config)

    @classmethod
    def name(cls):
        """Return the workflow's display name."""
        return "ERL-EDA"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct envs, agent, optimizers, collectors, buffer and evaluator."""
        # env for rl&ec rollout
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Optional global-norm gradient clipping in front of Adam.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        ec_optimizer = OpenES(
            pop_size=config.pop_size,
            lr_schedule=ExponentialScheduleSpec(**config.ec_lr),
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
            mirror_sampling=config.mirror_sampling,
        )

        # Fitness rollouts may optionally use exploratory actions.
        if config.fitness_with_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        # The RL collector chooses its action function independently.
        if config.rl_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        rl_collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor (and the RL actor)
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # this is only used for _ec_rollout(): params carry a leading pop
        # axis, the obs preprocessor state is shared.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        workflow = cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            ec_collector=ec_collector,
            rl_collector=rl_collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

        return workflow

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Init the RL agent, its actor/critic optimizer states and the EC state.

        The EC population center is initialized from the RL agent's initial
        actor parameters.
        """
        agent_key, ec_key = jax.random.split(key)

        # one agent for RL
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.actor_params

        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )

        return agent_state, opt_state, ec_opt_state

    # override
    def _rl_rollout(self, agent_state, replay_buffer_state, key):
        """Collect ``rollout_episodes`` episodes with the RL actor into the buffer.

        Returns:
            ``(eval_metrics, flattened_trajectory, replay_buffer_state)``.
        """
        # agent_state: only contains one agent
        # trajectory [T, B, ...]
        eval_metrics, trajectory = self.rl_collector.rollout(
            agent_state,
            key,
            self.config.rollout_episodes,
        )

        # Keep a step only if the previous step was not already done, i.e.
        # drop the padding after an episode ends (assuming the shift is along
        # the time axis — TODO confirm against right_shift_with_padding).
        mask = jnp.logical_not(right_shift_with_padding(trajectory.dones, 1))
        trajectory = clean_trajectory(trajectory)
        trajectory, mask = tree_stop_gradient(
            flatten_rollout_trajectory((trajectory, mask))
        )
        replay_buffer_state = self.replay_buffer.add(
            replay_buffer_state, trajectory, mask
        )

        return eval_metrics, trajectory, replay_buffer_state

    # override
    def _rl_update(self, agent_state, opt_state, replay_buffer_state, key):
        """Run ``num_rl_updates_per_iter`` TD3 updates from the replay buffer.

        Each update samples ``actor_update_interval`` batches (critic-only
        updates on all but the last, then critic + actor + targets on the
        last). Losses are aggregated over the updates by ``scan_and_mean``.
        """

        def _sample_fn(key):
            return self.replay_buffer.sample(replay_buffer_state, key)

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry

            key, rb_key, learn_key = jax.random.split(key, 3)

            rb_keys = jax.random.split(rb_key, self.config.actor_update_interval)
            # (actor_update_interval, B, ...)
            sample_batches = jax.vmap(_sample_fn)(rb_keys)

            (agent_state, opt_state), train_info = self._rl_update_fn(
                agent_state, opt_state, sample_batches, learn_key
            )

            return (key, agent_state, opt_state), train_info

        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (key, agent_state, opt_state),
            (),
            length=self.config.num_rl_updates_per_iter,
        )

        # smoothed td3 metrics
        td3_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        )

        return td3_metrics, agent_state, opt_state

    def _rl_injection(self, ec_opt_state, agent_state):
        """Move the EC population center toward the RL actor's parameters."""
        pop_mean = ec_opt_state.mean
        rl_actor_params = agent_state.params.actor_params

        # optax.incremental_update(new, old, s) = s * new + (1 - s) * old,
        # i.e. x = x + stepsize * (y - x) with x = pop mean, y = RL actor.
        ec_opt_state = ec_opt_state.replace(
            mean=optax.incremental_update(
                rl_actor_params, pop_mean, self.config.rl_injection_stepsize
            )
        )

        return ec_opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """One ERL-EDA iteration: EC rollout, RL rollout/update, EC update.

        Returns:
            ``(train_metrics, new_state)``. ``state.metrics`` is advanced by the
            environment timesteps and episodes consumed in this iteration.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, ec_rollout_key, rl_rollout_key, learn_key = jax.random.split(
            state.key, num=4
        )

        # ======== EC rollout ========
        # the trajectory [#pop, T, B, ...]
        # metrics: [#pop, B]
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)

        if self.config.mirror_sampling:
            # NOTE(review): rng_split_like_tree gives every leaf its own key, so
            # leaves are permuted independently along the pop axis, and the
            # fitnesses passed to `tell` below follow this permuted order.
            # Confirm this matches how OpenES pairs fitnesses with its noise.
            key, perm_key = jax.random.split(key)
            pop_actor_params = jtu.tree_map(
                lambda x, k: jax.random.permutation(k, x, axis=0),
                pop_actor_params,
                rng_split_like_tree(perm_key, pop_actor_params),
            )

        pop_agent_state = erl_replace_td3_actor_params(agent_state, pop_actor_params)
        ec_eval_metrics, ec_trajectory, replay_buffer_state = self._ec_rollout(
            pop_agent_state, replay_buffer_state, ec_rollout_key
        )

        ec_sampled_timesteps = ec_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        ec_sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        # ======== RL update ========
        rl_eval_metrics, rl_trajectory, replay_buffer_state = self._rl_rollout(
            agent_state, replay_buffer_state, rl_rollout_key
        )

        rl_sampled_timesteps = rl_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        rl_sampled_episodes = jnp.uint32(self.config.rollout_episodes)

        td3_metrics, agent_state, opt_state = self._rl_update(
            agent_state, opt_state, replay_buffer_state, learn_key
        )

        # ======== EC update ========
        # Fitness of each individual = mean return over its fitness episodes.
        fitnesses = ec_eval_metrics.episode_returns.mean(axis=-1)
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        ec_opt_state = jax.lax.cond(
            iterations % self.config.rl_injection_interval == 0,
            self._rl_injection,
            lambda ec_opt_state, agent_state: ec_opt_state,
            ec_opt_state,
            agent_state,
        )

        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=ec_eval_metrics.episode_returns.mean(-1),
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
            ec_info=ec_metrics,
            rb_size=replay_buffer_state.buffer_size,
        )

        # Timesteps add to timesteps and episodes to episodes. `learn` derives
        # its iteration budget from `sampled_episodes`, so mixing the two
        # would corrupt training length as well as the logged counters.
        sampled_timesteps = ec_sampled_timesteps + rl_sampled_timesteps
        sampled_episodes = ec_sampled_episodes + rl_sampled_episodes
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            rl_sampled_timesteps=state.metrics.rl_sampled_timesteps
            + rl_sampled_timesteps,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the RL actor and the population-center actor.

        Both use the same evaluator for ``eval_episodes`` episodes. Returns
        ``(EvaluateMetric, state)`` with only ``state.key`` advanced.
        """
        key, rl_eval_key, ec_eval_key = jax.random.split(state.key, num=3)

        rl_eval_metrics = self.evaluator.evaluate(
            state.agent_state, rl_eval_key, num_episodes=self.config.eval_episodes
        )

        pop_mean_actor_params = state.ec_opt_state.mean

        pop_mean_agent_state = erl_replace_td3_actor_params(
            state.agent_state, pop_mean_actor_params
        )

        ec_eval_metrics = self.evaluator.evaluate(
            pop_mean_agent_state, ec_eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(),
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(),
            pop_center_episode_returns=ec_eval_metrics.episode_returns.mean(),
            pop_center_episode_lengths=ec_eval_metrics.episode_lengths.mean(),
        )

        state = state.replace(key=key)

        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Train until about ``total_episodes`` episodes have been sampled.

        The number of remaining iterations is derived from the episodes each
        ``step`` consumes (``episodes_for_fitness * pop_size +
        rollout_episodes``). Metrics are recorded every iteration; evaluation
        runs every ``eval_interval`` iterations and on the last one, and
        checkpoints are saved every iteration (forced on the last one).
        """
        sampled_episodes_per_iter = (
            self.config.episodes_for_fitness * self.config.pop_size
            + self.config.rollout_episodes
        )
        num_iters = math.ceil(
            (self.config.total_episodes - state.metrics.sampled_episodes)
            / sampled_episodes_per_iter
        )

        final_iteration = num_iters + state.metrics.iterations
        for i in range(state.metrics.iterations, final_iteration):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            self.recorder.write(train_metrics_dict, iters)

            if iters % self.config.eval_interval == 0 or iters == final_iteration:
                eval_metrics, state = self.evaluate(state)

                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            # Optionally drop the (large) replay buffer from the checkpoint.
            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=iters == final_iteration,
            )

        return state
438 439
def build_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
):
    """Build the single-agent TD3 update step used by ``ERLEDAWorkflow``.

    The returned ``_update_fn(agent_state, opt_state, sample_batches, key)``
    expects every leaf of ``sample_batches`` to carry a leading axis of length
    ``config.actor_update_interval``:

    - all batches except the last drive critic-only updates (via ``lax.scan``);
    - the last batch drives one more critic update, one actor update, and a
      soft update of both target networks with rate ``config.tau``.

    Args:
        agent: TD3 agent exposing ``critic_loss`` and ``actor_loss``.
        optimizer: optax transformation applied to both actor and critic; each
            keeps its own optimizer state (``opt_state.actor`` /
            ``opt_state.critic``).
        config: Provides ``actor_update_interval`` and ``tau``.

    Returns:
        The update function. It returns ``((agent_state, opt_state),
        (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict))``, where
        the critic values come from the final critic update only.
    """

    def critic_loss_fn(agent_state, sample_batch, key):
        # TD3 critic loss on one batch; the full loss dict is returned as aux.
        loss_dict = agent.critic_loss(agent_state, sample_batch, key)
        return loss_dict.critic_loss, loss_dict

    def actor_loss_fn(agent_state, sample_batch, key):
        # TD3 actor loss for the single RL actor on the same batch shape.
        loss_dict = agent.actor_loss(agent_state, sample_batch, key)
        return loss_dict.actor_loss, loss_dict

    def _with_critic_params(agent_state, critic_params):
        return agent_state.replace(
            params=agent_state.params.replace(critic_params=critic_params)
        )

    def _with_actor_params(agent_state, actor_params):
        return agent_state.replace(
            params=agent_state.params.replace(actor_params=actor_params)
        )

    critic_step = agent_gradient_update(
        critic_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_with_critic_params,
        detach_fn=lambda agent_state: agent_state.params.critic_params,
    )

    actor_step = agent_gradient_update(
        actor_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_with_actor_params,
        detach_fn=lambda agent_state: agent_state.params.actor_params,
    )

    def _update_fn(agent_state, opt_state, sample_batches, key):
        key, final_critic_key, actor_key = jax.random.split(key, num=3)

        critic_opt = opt_state.critic
        actor_opt = opt_state.actor

        # Split the stacked batches: leading ones for extra critic steps, the
        # last one for the final critic + actor step.
        leading_batches = jtu.tree_map(lambda leaf: leaf[:-1], sample_batches)
        final_batch = jtu.tree_map(lambda leaf: leaf[-1], sample_batches)

        num_extra_critic_steps = config.actor_update_interval - 1
        if num_extra_critic_steps > 0:

            def _critic_only_step(carry, batch):
                loop_key, state, c_opt = carry
                loop_key, step_key = jax.random.split(loop_key)
                _, state, c_opt = critic_step(c_opt, state, batch, step_key)
                return (loop_key, state, c_opt), None

            key, scan_key = jax.random.split(key)
            (_, agent_state, critic_opt), _ = jax.lax.scan(
                _critic_only_step,
                (scan_key, agent_state, critic_opt),
                leading_batches,
                length=num_extra_critic_steps,
            )

        (critic_loss, critic_loss_dict), agent_state, critic_opt = critic_step(
            critic_opt, agent_state, final_batch, final_critic_key
        )

        (actor_loss, actor_loss_dict), agent_state, actor_opt = actor_step(
            actor_opt, agent_state, final_batch, actor_key
        )

        # Single agent: target networks are updated directly (no vmap).
        params = agent_state.params
        agent_state = agent_state.replace(
            params=params.replace(
                target_actor_params=soft_target_update(
                    params.target_actor_params, params.actor_params, config.tau
                ),
                target_critic_params=soft_target_update(
                    params.target_critic_params, params.critic_params, config.tau
                ),
            )
        )

        new_opt_state = opt_state.replace(actor=actor_opt, critic=critic_opt)
        losses = (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict)

        return (agent_state, new_opt_state), losses

    return _update_fn