Source code for evorl.algorithms.erl.erl_td3.erl_es

  1import logging
  2import math
  3from omegaconf import DictConfig
  4
  5import chex
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9import optax
 10
 11from evorl.replay_buffers import ReplayBuffer
 12from evorl.metrics import MetricBase
 13from evorl.types import PyTreeDict, State
 14from evorl.utils.jax_utils import tree_get
 15from evorl.evaluators import Evaluator, EpisodeCollector
 16from evorl.agent import AgentState
 17from evorl.envs import create_env, AutoresetMode
 18from evorl.recorders import get_1d_array_statistics, add_prefix, get_1d_array
 19from evorl.ec.optimizers import ECState, VanillaESMod, ExponentialScheduleSpec
 20from evorl.algorithms.td3 import make_mlp_td3_agent
 21from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
 22
 23from ..erl_workflow import ERLTrainMetric
 24from .erl_td3_workflow import ERLTD3WorkflowTemplate, erl_replace_td3_actor_params
 25
 26logger = logging.getLogger(__name__)
 27
 28
class EvaluateMetric(MetricBase):
    """Evaluation metrics reported by ``ERLESWorkflow.evaluate``."""

    # Per-RL-agent episode returns, averaged over the last axis of the
    # evaluator output.
    rl_episode_returns: chex.Array
    # Per-RL-agent episode lengths, averaged over the last axis of the
    # evaluator output.
    rl_episode_lengths: chex.Array
    # Mean episode return of the actor built from the ES mean parameters.
    pop_center_episode_returns: chex.Array
    # Mean episode length of the actor built from the ES mean parameters.
    pop_center_episode_lengths: chex.Array
class ERLESWorkflow(ERLTD3WorkflowTemplate):
    """ERL w/ ES.

    Configs:

    - EC: n actors
    - RL: k actors + k critics
    - Shared replay buffer
    """

    @classmethod
    def name(cls) -> str:
        """Return the display name of this workflow."""
        return "ERL-ES"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the workflow and all of its components from ``config``.

        Creates the rollout environment, the TD3 agent, the gradient
        optimizer (Adam, optionally preceded by global-norm clipping), the
        ES optimizer, one episode collector for the EC population and one
        for the RL agents, the replay buffer and the evaluator.

        Args:
            config: Workflow configuration.

        Returns:
            A new instance of ``cls`` wired with the components above.
        """
        # env for rl&ec rollout
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is only enabled for a positive clip norm.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        ec_optimizer = VanillaESMod(
            pop_size=config.pop_size,
            external_size=config.num_rl_agents,
            num_elites=config.num_elites,
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
            mix_strategy=config.mix_strategy,
        )

        # Action function used when rolling out the EC population.
        if config.fitness_with_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Action function used when rolling out the RL agents.
        if config.rl_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        rl_collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # vmap over the leading axis of params; the obs preprocessor state
        # is shared (not mapped).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        workflow = cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            ec_collector=ec_collector,
            rl_collector=rl_collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

        return workflow

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Initialize the RL agents, their optimizer state and the ES state.

        Args:
            key: PRNG key used for all initializations.

        Returns:
            A tuple ``(agent_state, opt_state, ec_opt_state)`` where
            ``agent_state`` holds ``config.num_rl_agents`` stacked agents,
            ``opt_state`` has ``actor`` and ``critic`` entries, and
            ``ec_opt_state`` is the ES state initialized from a separately
            initialized set of actor parameters.
        """
        agent_key, pop_agent_key, ec_key = jax.random.split(key, 3)

        # agent for RL
        agent_state = jax.vmap(self.agent.init, in_axes=(None, None, 0))(
            self.env.obs_space,
            self.env.action_space,
            jax.random.split(agent_key, self.config.num_rl_agents),
        )

        # all agents will share the same obs_preprocessor_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=tree_get(agent_state.obs_preprocessor_state, 0)
            )

        # NOTE(review): `key` was already split above and is reused here.
        # Presumably harmless because the sample only serves as a dummy
        # input for network initialization -- confirm. Left unchanged so
        # that results for existing seeds stay reproducible.
        dummy_obs = self.env.obs_space.sample(key)
        init_actor_params = self.agent.actor_network.init(
            pop_agent_key, jtu.tree_map(lambda x: x[None, ...], dummy_obs)
        )

        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )

        return agent_state, opt_state, ec_opt_state

    def _rl_injection(
        self,
        ec_opt_state: ECState,
        agent_state: AgentState,
        ec_fitnesses: chex.Array,
        rl_fitnesses: chex.Array,
    ) -> tuple[chex.Array, ECState]:
        """Append the RL actors to the ES population as extra noise samples.

        Each RL actor is expressed as an offset from the current ES mean and
        concatenated after the ES noise; the fitnesses are concatenated in
        the same order (EC first, RL second).

        Args:
            ec_opt_state: Current ES state (provides ``mean`` and ``noise``).
            agent_state: RL agent state holding the actor parameters.
            ec_fitnesses: Fitnesses of the EC population.
            rl_fitnesses: Fitnesses of the RL agents.

        Returns:
            A tuple ``(fitnesses, ec_opt_state)`` with the concatenated
            fitnesses and the ES state whose ``noise`` includes the RL
            offsets.
        """
        rl_noise = jtu.tree_map(
            lambda x, m: x - m,
            agent_state.params.actor_params,
            ec_opt_state.mean,
        )

        concat_noise = jtu.tree_map(
            lambda n1, n2: jnp.concatenate([n1, n2], axis=0),
            ec_opt_state.noise,
            rl_noise,
        )

        ec_opt_state = ec_opt_state.replace(noise=concat_noise)

        fitnesses = jnp.concatenate([ec_fitnesses, rl_fitnesses], axis=0)

        return fitnesses, ec_opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration.

        The iteration performs, in order: EC rollout, RL rollout, RL (TD3)
        update, injection of the updated RL actors into the ES population,
        and the ES update.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, state)`` with the metrics of this
            iteration and the updated workflow state.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, ec_rollout_key, rl_rollout_key, learn_key = jax.random.split(
            state.key, num=4
        )

        # ======== EC & RL rollout ========
        # the trajectory [#pop, T, B, ...]
        # metrics: [#pop, B]
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)

        pop_agent_state = erl_replace_td3_actor_params(agent_state, pop_actor_params)
        ec_eval_metrics, ec_trajectory, replay_buffer_state = self._ec_rollout(
            pop_agent_state, replay_buffer_state, ec_rollout_key
        )

        ec_sampled_timesteps = ec_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        ec_sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        rl_eval_metrics, rl_trajectory, replay_buffer_state = self._rl_rollout(
            agent_state, replay_buffer_state, rl_rollout_key
        )

        rl_sampled_timesteps = rl_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        rl_sampled_episodes = jnp.uint32(
            self.config.num_rl_agents * self.config.rollout_episodes
        )

        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=ec_eval_metrics.episode_returns.mean(-1),
        )

        # ======== RL update ========
        td3_metrics, agent_state, opt_state = self._rl_update(
            agent_state, opt_state, replay_buffer_state, learn_key
        )

        # get average loss
        td3_metrics = td3_metrics.replace(
            actor_loss=td3_metrics.actor_loss / self.config.num_rl_agents,
            critic_loss=td3_metrics.critic_loss / self.config.num_rl_agents,
        )

        train_metrics = train_metrics.replace(
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
        )

        # ======== EC update ========
        # inject RL into EC

        # The RL fitnesses come from the rollout above (before the RL
        # update), while the injected parameters are the updated ones.
        ec_fitnesses = ec_eval_metrics.episode_returns.mean(axis=-1)
        rl_fitnesses = rl_eval_metrics.episode_returns.mean(axis=-1)
        fitnesses, ec_opt_state = self._rl_injection(
            ec_opt_state, agent_state, ec_fitnesses, rl_fitnesses
        )

        ec_metrics, ec_opt_state = self.ec_optimizer.tell_external(
            ec_opt_state, fitnesses
        )

        train_metrics = train_metrics.replace(
            ec_info=ec_metrics, rb_size=replay_buffer_state.buffer_size
        )

        # calculate the number of timestep
        sampled_timesteps = ec_sampled_timesteps + rl_sampled_timesteps
        sampled_episodes = ec_sampled_episodes + rl_sampled_episodes
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            rl_sampled_timesteps=state.metrics.rl_sampled_timesteps
            + rl_sampled_timesteps,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate every RL agent and the actor at the ES mean.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(eval_metrics, state)`` where ``eval_metrics`` is an
            ``EvaluateMetric`` and ``state`` only differs from the input by
            its advanced PRNG key.
        """
        key, rl_eval_key, ec_eval_key = jax.random.split(state.key, num=3)

        # One evaluation per RL agent, each with its own key.
        rl_eval_metrics = jax.vmap(
            self.evaluator.evaluate, in_axes=(self.agent_state_vmap_axes, 0, None)
        )(
            state.agent_state,
            jax.random.split(rl_eval_key, num=self.config.num_rl_agents),
            self.config.eval_episodes,
        )

        pop_mean_actor_params = state.ec_opt_state.mean

        pop_mean_agent_state = erl_replace_td3_actor_params(
            state.agent_state, pop_mean_actor_params
        )

        ec_eval_metrics = self.evaluator.evaluate(
            pop_mean_agent_state, ec_eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            pop_center_episode_returns=ec_eval_metrics.episode_returns.mean(),
            pop_center_episode_lengths=ec_eval_metrics.episode_lengths.mean(),
        )

        state = state.replace(key=key)

        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Train until ``config.total_episodes`` episodes have been sampled.

        Each iteration calls ``step``, records the workflow and training
        metrics, evaluates every ``config.eval_interval`` iterations (and on
        the final one), and saves a checkpoint.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The workflow state after the last iteration. If the episode
            budget is already exhausted, ``state`` is returned unchanged.
        """
        sampled_episodes_per_iter = (
            self.config.episodes_for_fitness * self.config.pop_size
            + self.config.rollout_episodes * self.config.num_rl_agents
        )

        # Do the budget arithmetic on Python ints. The episode counter is
        # accumulated with uint32 increments in `step`; subtracting it in
        # array arithmetic would wrap around to a huge value when a resumed
        # state has already overshot the budget.
        remaining_episodes = self.config.total_episodes - int(
            state.metrics.sampled_episodes
        )
        num_iters = max(0, math.ceil(remaining_episodes / sampled_episodes_per_iter))

        start_iteration = int(state.metrics.iterations)
        final_iteration = start_iteration + num_iters
        for i in range(start_iteration, final_iteration):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            if self.config.num_rl_agents > 1:
                # Several RL agents: report statistics over the agents.
                train_metrics_dict["rl_episode_lengths"] = get_1d_array_statistics(
                    train_metrics_dict["rl_episode_lengths"], histogram=True
                )
                train_metrics_dict["rl_episode_returns"] = get_1d_array_statistics(
                    train_metrics_dict["rl_episode_returns"], histogram=True
                )
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )
            else:
                # Single RL agent: drop the leading agent axis.
                train_metrics_dict["rl_episode_lengths"] = train_metrics_dict[
                    "rl_episode_lengths"
                ].squeeze(0)
                train_metrics_dict["rl_episode_returns"] = train_metrics_dict[
                    "rl_episode_returns"
                ].squeeze(0)

            self.recorder.write(train_metrics_dict, iters)

            if iters % self.config.eval_interval == 0 or iters == final_iteration:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = eval_metrics.to_local_dict()
                if self.config.num_rl_agents > 1:
                    eval_metrics_dict = jtu.tree_map(get_1d_array, eval_metrics_dict)
                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            # Always persist the checkpoint of the final iteration.
            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=iters == final_iteration,
            )

        return state