Source code for evorl.algorithms.contrib.pop_episodic_td3

  1from functools import partial
  2import math
  3from typing_extensions import Self  # pytype: disable=not-supported-yet]
  4from omegaconf import DictConfig
  5
  6import chex
  7import jax
  8import jax.numpy as jnp
  9import jax.tree_util as jtu
 10import optax
 11
 12from evorl.distributed import agent_gradient_update
 13from evorl.agent import Agent, AgentState
 14from evorl.types import PyTreeDict, State
 15from evorl.metrics import MetricBase, EvaluateMetric
 16from evorl.replay_buffers import ReplayBuffer
 17from evorl.envs import create_env, AutoresetMode
 18from evorl.utils.rl_toolkits import (
 19    soft_target_update,
 20)
 21from evorl.recorders import add_prefix, get_1d_array_statistics, get_1d_array
 22from evorl.evaluators import Evaluator, EpisodeCollector
 23
 24
 25from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
 26from evorl.algorithms.td3 import make_mlp_td3_agent
 27from evorl.algorithms.erl.cemrl_td3.cemrl_td3_workflow import CEMRLTD3WorkflowTemplate
 28from evorl.algorithms.erl.cemrl_workflow import CEMRLTrainMetric
 29
 30
class PopEpisodicTD3Workflow(CEMRLTD3WorkflowTemplate):
    """A batched TD3 workflow organized like CEM-RL.

    The differences from CEM-RL are:
    - Each individual has its own actor and its own critic.
    - All individuals are updated by RL (``pop_size`` must equal
      ``num_learning_offspring``); no EC optimizer is used.
    """

    def __init__(self, **kwargs):
        # ``super(CEMRLTD3WorkflowTemplate, self)`` skips
        # ``CEMRLTD3WorkflowTemplate.__init__`` and runs the next ``__init__`` in
        # the MRO; this class then installs its own per-individual RL update fn.
        super(CEMRLTD3WorkflowTemplate, self).__init__(**kwargs)
        self._rl_update_fn = build_rl_update_fn(
            self.agent,
            self.optimizer,
            self.config,
            self.agent_state_vmap_axes,
        )

    @classmethod
    def name(cls):
        return "PopEpisodicTD3"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Build the workflow (envs, agent, optimizer, collector, buffer, evaluator).

        Raises:
            AssertionError: if ``config.random_timesteps`` is not positive or
                ``config.pop_size != config.num_learning_offspring``.
        """
        assert config.random_timesteps > 0, (
            "random_timesteps should be positive to pre-fill some data in the replay buffer"
        )

        assert config.pop_size == config.num_learning_offspring, (
            "pop_size must equal to num_learning_offspring"
        )

        # env for one actor
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            # NOTE(review): the params this optimizer sees carry a leading
            # population axis (see _setup_agent_and_optimizer), so the global
            # norm presumably spans the whole population rather than each
            # individual — confirm this coupling is intended.
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        # Fitness rollouts either keep exploration noise or act greedily.
        if config.fitness_with_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Params are batched along axis 0 (one set per individual); the
        # observation preprocessor state is shared across the population.
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,  # shared
        )

        workflow = cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=None,
            collector=collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

        return workflow

    def _setup_agent_and_optimizer(self, key: chex.PRNGKey):
        """Initialize ``pop_size`` independent agents and their optimizer states.

        Returns:
            ``(agent_state, opt_state, ec_opt_state)`` where ``agent_state`` has a
            leading population axis, ``opt_state`` holds separate ``actor`` and
            ``critic`` optimizer states, and ``ec_opt_state`` is ``None``.
        """
        agent_state = jax.vmap(self.agent.init, in_axes=(None, None, 0))(
            self.env.obs_space,
            self.env.action_space,
            jax.random.split(key, self.config.pop_size),
        )

        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )

        ec_opt_state = None

        return agent_state, opt_state, ec_opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one iteration: an RL update of every individual, then a rollout.

        Returns:
            ``(train_metrics, new_state)``.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # ======== RL update ========
        td3_metrics, agent_state, opt_state = self._rl_update(
            agent_state,
            opt_state,
            replay_buffer_state,
            learn_key,
        )

        # the trajectory [T, #pop*B, ...]
        # metrics: [#pop, B]
        eval_metrics, trajectory, replay_buffer_state = self._rollout(
            agent_state, replay_buffer_state, rollout_key
        )

        train_metrics = CEMRLTrainMetric(
            rb_size=replay_buffer_state.buffer_size,
            pop_episode_lengths=eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
        )

        # calculate the number of timestep
        sampled_timesteps = eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate every individual for ``config.eval_episodes`` episodes.

        Returns:
            ``(eval_metrics, new_state)``; the metrics hold one mean episode
            return/length per individual.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        # [#pop, #episodes]
        raw_eval_metrics = jax.vmap(
            partial(self.evaluator.evaluate, num_episodes=self.config.eval_episodes),
            in_axes=(self.agent_state_vmap_axes, 0),
        )(
            state.agent_state,
            jax.random.split(eval_key, self.config.pop_size),
        )

        eval_metrics = EvaluateMetric(
            episode_returns=raw_eval_metrics.episode_returns.mean(-1),
            episode_lengths=raw_eval_metrics.episode_lengths.mean(-1),
        )

        state = state.replace(key=key)
        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Train until ``config.total_episodes`` episodes have been sampled.

        Each iteration runs :meth:`step`, records workflow and train metrics,
        evaluates every ``config.eval_interval`` iterations (and on the final
        one), and saves a checkpoint (forced on the final iteration).

        Args:
            state: the current workflow state, possibly restored from a checkpoint.

        Returns:
            The workflow state after the last iteration (unchanged if the
            episode budget is already exhausted).
        """
        episodes_per_iter = self.config.episodes_for_fitness * self.config.pop_size

        # Pull the on-device counters out as Python ints before doing arithmetic.
        # step() accumulates sampled_episodes with jnp.uint32, so subtracting it
        # on-device wraps around once it exceeds total_episodes (e.g. resuming a
        # finished run, since the ceil below can overshoot), which would schedule
        # ~2**32 / episodes_per_iter iterations.
        sampled_episodes = int(state.metrics.sampled_episodes)
        start_iteration = int(state.metrics.iterations)
        num_iters = max(
            0,
            math.ceil(
                (self.config.total_episodes - sampled_episodes) / episodes_per_iter
            ),
        )

        final_iteration = start_iteration + num_iters
        for i in range(start_iteration, final_iteration):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            if train_metrics_dict["rl_metrics"] is not None:
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )

            self.recorder.write(train_metrics_dict, iters)

            if iters % self.config.eval_interval == 0 or iters == final_iteration:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = jtu.tree_map(
                    get_1d_array,
                    eval_metrics.to_local_dict(),
                )

                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=iters == final_iteration,
            )

        return state
288 289
def build_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
    agent_state_vmap_axes: AgentState,
):
    """Build the TD3 update for a population where each individual has its own actor and critic.

    Per-individual critic and actor losses are computed with ``jax.vmap`` over the
    population axis (described by ``agent_state_vmap_axes``) and summed. Assuming
    the params are batched along that axis, differentiating the sum gives each
    individual the gradient of its own loss.

    Args:
        agent: the TD3 agent providing ``critic_loss`` and ``actor_loss``.
        optimizer: optax transformation used for both actor and critic updates.
        config: must provide ``num_learning_offspring``, ``actor_update_interval``
            and ``tau``.
        agent_state_vmap_axes: vmap axes of the population ``AgentState``.

    Returns:
        ``_update_fn(agent_state, opt_state, sample_batches, key)``. It performs
        ``config.actor_update_interval`` critic updates (the leading axis of
        ``sample_batches`` must have that length), then one actor update on the
        last batch, then a soft update of both target networks. It returns
        ``((agent_state, opt_state),
        (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict))``, where
        the critic values come from the last critic update.
    """
    num_learning_offspring = config.num_learning_offspring

    def critic_loss_fn(agent_state, sample_batch, key):
        # Per-individual critic losses; sample_batch has a leading population
        # axis: (#pop, B, ...)
        loss_dict = jax.vmap(agent.critic_loss, in_axes=(agent_state_vmap_axes, 0, 0))(
            agent_state, sample_batch, jax.random.split(key, num_learning_offspring)
        )

        # Sum (not mean) over the population so each individual's gradient is
        # the gradient of its own loss.
        loss = loss_dict.critic_loss.sum()

        return loss, loss_dict

    def actor_loss_fn(agent_state, sample_batch, key):
        # Per-individual actor losses, vmapped over the population axis.
        loss_dict = jax.vmap(agent.actor_loss, in_axes=(agent_state_vmap_axes, 0, 0))(
            agent_state, sample_batch, jax.random.split(key, num_learning_offspring)
        )

        # sum over the num_learning_offspring
        loss = loss_dict.actor_loss.sum()

        return loss, loss_dict

    critic_update_fn = agent_gradient_update(
        critic_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=lambda agent_state, critic_params: agent_state.replace(
            params=agent_state.params.replace(critic_params=critic_params)
        ),
        detach_fn=lambda agent_state: agent_state.params.critic_params,
    )

    actor_update_fn = agent_gradient_update(
        actor_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=lambda agent_state, actor_params: agent_state.replace(
            params=agent_state.params.replace(actor_params=actor_params)
        ),
        detach_fn=lambda agent_state: agent_state.params.actor_params,
    )

    def _update_fn(agent_state, opt_state, sample_batches, key):
        critic_opt_state = opt_state.critic
        actor_opt_state = opt_state.actor

        key, critic_key, actor_key = jax.random.split(key, num=3)

        # The first (actor_update_interval - 1) batches drive critic-only
        # updates; the last batch drives one more critic update plus the actor.
        critic_sample_batches = jtu.tree_map(lambda x: x[:-1], sample_batches)
        last_sample_batch = jtu.tree_map(lambda x: x[-1], sample_batches)

        if config.actor_update_interval - 1 > 0:

            def _update_critic_fn(carry, sample_batch):
                key, agent_state, critic_opt_state = carry

                key, critic_key = jax.random.split(key)

                # Intermediate losses are discarded; only the last update's
                # losses are reported.
                _, agent_state, critic_opt_state = critic_update_fn(
                    critic_opt_state, agent_state, sample_batch, critic_key
                )

                return (key, agent_state, critic_opt_state), None

            key, critic_multiple_update_key = jax.random.split(key)

            (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                _update_critic_fn,
                (
                    critic_multiple_update_key,
                    agent_state,
                    critic_opt_state,
                ),
                critic_sample_batches,
                length=config.actor_update_interval - 1,
            )

        (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
            critic_update_fn(
                critic_opt_state, agent_state, last_sample_batch, critic_key
            )
        )

        (actor_loss, actor_loss_dict), agent_state, actor_opt_state = actor_update_fn(
            actor_opt_state, agent_state, last_sample_batch, actor_key
        )

        # No vmap needed: the soft target update acts leaf-wise on the
        # population-batched params.
        target_actor_params = soft_target_update(
            agent_state.params.target_actor_params,
            agent_state.params.actor_params,
            config.tau,
        )
        target_critic_params = soft_target_update(
            agent_state.params.target_critic_params,
            agent_state.params.critic_params,
            config.tau,
        )
        agent_state = agent_state.replace(
            params=agent_state.params.replace(
                target_actor_params=target_actor_params,
                target_critic_params=target_critic_params,
            )
        )

        opt_state = opt_state.replace(actor=actor_opt_state, critic=critic_opt_state)

        return (
            (agent_state, opt_state),
            (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
        )

    return _update_fn