Source code for evorl.algorithms.erl.erl_td3.erl_ga

  1import logging
  2import math
  3from omegaconf import DictConfig
  4
  5import chex
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9import optax
 10
 11from evorl.replay_buffers import ReplayBuffer
 12from evorl.metrics import MetricBase
 13from evorl.types import PyTreeDict, State
 14from evorl.evaluators import Evaluator, EpisodeCollector
 15from evorl.agent import AgentState
 16from evorl.envs import create_env, AutoresetMode
 17from evorl.recorders import get_1d_array_statistics, add_prefix, get_1d_array
 18from evorl.ec.optimizers import ERLGAMod, ECState
 19from evorl.algorithms.td3 import make_mlp_td3_agent
 20from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
 21
 22from ..erl_workflow import ERLTrainMetric
 23from .erl_td3_workflow import ERLTD3WorkflowTemplate, erl_replace_td3_actor_params
 24
 25logger = logging.getLogger(__name__)
 26
 27
class EvaluateMetric(MetricBase):
    """Evaluation metrics produced by ``ERLGAWorkflow.evaluate``.

    Both fields are filled with the evaluator's raw per-episode metrics
    averaged over their last axis.
    """

    # Mean episode return of the RL agent(s).
    rl_episode_returns: chex.Array
    # Mean episode length of the RL agent(s).
    rl_episode_lengths: chex.Array
31 32
class ERLGAWorkflow(ERLTD3WorkflowTemplate):
    """ERL w/ GA.

    Config:

    - EC: n actors
    - RL: k actors + k critics
    - Shared replay buffer
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "ERL-GA"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the workflow and all of its components from ``config``.

        Creates the rollout env, the TD3 agent, the gradient optimizer, the
        GA optimizer, the EC/RL episode collectors, the replay buffer and the
        evaluator, then passes them to the class constructor.

        Args:
            config: The workflow configuration.

        Returns:
            The constructed workflow instance.
        """
        assert config.num_elites >= config.num_rl_agents, (
            "config.num_elites must be >= config.num_rl_agents"
        )

        # env for rl&ec rollout
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Adam, optionally preceded by global-norm gradient clipping when a
        # positive clip norm is configured.
        grad_clip_norm = config.optimizer.grad_clip_norm
        optimizer = optax.adam(config.optimizer.lr)
        if grad_clip_norm is not None and grad_clip_norm > 0:
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optimizer,
            )

        ec_optimizer = ERLGAMod(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            external_size=config.num_rl_agents,
            weight_max_magnitude=config.weight_max_magnitude,
            mut_strength=config.mut_strength,
            num_mutation_frac=config.num_mutation_frac,
            super_mut_strength=config.super_mut_strength,
            super_mut_prob=config.super_mut_prob,
            reset_prob=config.reset_prob,
            vec_relative_prob=config.vec_relative_prob,
            enable_crossover=config.enable_crossover,
            num_crossover_frac=config.num_crossover_frac,
        )

        # Population rollouts use either the agent's `compute_actions` or
        # its `evaluate_actions`, selected by `fitness_with_exploration`.
        ec_action_fn = (
            agent.compute_actions
            if config.fitness_with_exploration
            else agent.evaluate_actions
        )
        ec_collector = EpisodeCollector(
            env=env,
            action_fn=ec_action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Same choice for the RL agents, selected by `rl_exploration`.
        rl_action_fn = (
            agent.compute_actions
            if config.rl_exploration
            else agent.evaluate_actions
        )
        rl_collector = EpisodeCollector(
            env=env,
            action_fn=rl_action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # env for the evaluator; `evaluate` runs the RL agents on it
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # vmap over the leading axis of the params; the obs preprocessor
        # state is shared (not batched).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            ec_collector=ec_collector,
            rl_collector=rl_collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Initialize the RL agents, their optimizer states and the GA state.

        Args:
            key: PRNG key used to derive all initialization keys.

        Returns:
            A tuple ``(agent_state, opt_state, ec_opt_state)``:
            the batched state of the ``num_rl_agents`` RL agents, a dict with
            the ``actor`` / ``critic`` optimizer states, and the GA optimizer
            state initialized from ``pop_size`` actor parameter sets.
        """
        agent_key, pop_agent_key, ec_key = jax.random.split(key, 3)

        # agent for RL
        agent_state = jax.vmap(self.agent.init, in_axes=(None, None, 0))(
            self.env.obs_space,
            self.env.action_space,
            jax.random.split(agent_key, self.config.num_rl_agents),
        )

        # all agents will share the same obs_preprocessor_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=jtu.tree_map(
                    lambda x: x[0], agent_state.obs_preprocessor_state
                )
            )

        # NOTE(review): `key` was already split above and is reused here to
        # draw the sample observation fed to the actor-network initializer.
        # Presumably only its shape/dtype matters -- confirm before changing,
        # since altering the split would change the derived init keys.
        dummy_obs = self.env.obs_space.sample(key)
        pop_actor_params = jax.vmap(self.agent.actor_network.init, in_axes=(0, None))(
            jax.random.split(pop_agent_key, self.config.pop_size),
            # add a leading axis of size 1 to every leaf of the observation
            jtu.tree_map(lambda x: x[None, ...], dummy_obs),
        )

        ec_opt_state = self.ec_optimizer.init(pop_actor_params, ec_key)

        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )

        return agent_state, opt_state, ec_opt_state

    def _rl_injection(self, ec_opt_state: ECState, agent_state: AgentState) -> ECState:
        """Store the RL actors' parameters as the GA's ``external_pop``.

        Args:
            ec_opt_state: Current GA optimizer state.
            agent_state: State of the RL agents; its actor params must have a
                leading axis of size ``num_rl_agents``.

        Returns:
            The GA optimizer state with ``external_pop`` replaced.
        """
        rl_actor_params = agent_state.params.actor_params
        chex.assert_tree_shape_prefix(rl_actor_params, (self.config.num_rl_agents,))

        return ec_opt_state.replace(external_pop=rl_actor_params)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """The basic step function for the workflow to update agent.

        One iteration performs, in order: an EC (population) rollout, an RL
        rollout, an RL (TD3) update, and a GA update. Every
        ``rl_injection_interval`` iterations the GA update also receives the
        RL actors via ``tell_external``.

        Args:
            state: The current workflow state.

        Returns:
            A tuple ``(train_metrics, state)`` with this iteration's training
            metrics and the updated workflow state.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, ec_rollout_key, rl_rollout_key, learn_key = jax.random.split(
            state.key, num=4
        )

        # ======== EC rollout ========
        # the trajectory [#pop, T, B, ...]
        # metrics: [#pop, B]
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)
        pop_agent_state = erl_replace_td3_actor_params(agent_state, pop_actor_params)
        # the returned trajectory is not used here
        ec_eval_metrics, _, replay_buffer_state = self._ec_rollout(
            pop_agent_state, replay_buffer_state, ec_rollout_key
        )

        ec_sampled_timesteps = ec_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        ec_sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        # ======== RL update ========

        # the returned trajectory is not used here
        rl_eval_metrics, _, replay_buffer_state = self._rl_rollout(
            agent_state, replay_buffer_state, rl_rollout_key
        )

        rl_sampled_timesteps = rl_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        rl_sampled_episodes = jnp.uint32(
            self.config.num_rl_agents * self.config.rollout_episodes
        )

        td3_metrics, agent_state, opt_state = self._rl_update(
            agent_state, opt_state, replay_buffer_state, learn_key
        )

        # get average loss
        # (presumably `_rl_update` reports losses summed over the RL agents;
        # it is defined outside this file -- TODO confirm)
        td3_metrics = td3_metrics.replace(
            actor_loss=td3_metrics.actor_loss / self.config.num_rl_agents,
            critic_loss=td3_metrics.critic_loss / self.config.num_rl_agents,
        )

        # ====== EC update ======
        # fitness of each individual: mean return over its episodes
        fitnesses = ec_eval_metrics.episode_returns.mean(axis=-1)

        def _tell_with_rl_injection(ec_opt_state, fitnesses):
            # uses the RL agents as updated by `_rl_update` above
            ec_opt_state = self._rl_injection(ec_opt_state, agent_state)
            ec_metrics, ec_opt_state = self.ec_optimizer.tell_external(
                ec_opt_state, fitnesses
            )
            return ec_metrics, ec_opt_state

        ec_metrics, ec_opt_state = jax.lax.cond(
            iterations % self.config.rl_injection_interval == 0,
            _tell_with_rl_injection,
            self.ec_optimizer.tell,
            ec_opt_state,
            fitnesses,
        )

        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=ec_eval_metrics.episode_returns.mean(-1),
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
            ec_info=ec_metrics,
            rb_size=replay_buffer_state.buffer_size,
        )

        # calculate the number of timestep
        sampled_timesteps = ec_sampled_timesteps + rl_sampled_timesteps
        sampled_episodes = ec_sampled_episodes + rl_sampled_episodes
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            rl_sampled_timesteps=state.metrics.rl_sampled_timesteps
            + rl_sampled_timesteps,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the RL agents with the evaluator.

        Args:
            state: The current workflow state.

        Returns:
            A tuple ``(eval_metrics, state)``: the per-agent mean episode
            returns/lengths and the state with its PRNG key advanced.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        # [num_rl_agents, #episodes]
        raw_eval_metrics = jax.vmap(
            self.evaluator.evaluate, in_axes=(self.agent_state_vmap_axes, 0, None)
        )(
            state.agent_state,
            jax.random.split(eval_key, self.config.num_rl_agents),
            self.config.eval_episodes,
        )

        eval_metrics = EvaluateMetric(
            rl_episode_returns=raw_eval_metrics.episode_returns.mean(-1),
            rl_episode_lengths=raw_eval_metrics.episode_lengths.mean(-1),
        )

        state = state.replace(key=key)
        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Run the training loop until ``total_episodes`` have been sampled.

        Each iteration calls ``step``, records the workflow and training
        metrics, periodically evaluates, and saves a checkpoint.

        Args:
            state: The workflow state to start (or resume) from.

        Returns:
            The workflow state after the last iteration. If the episode
            budget is already exhausted, ``state`` is returned unchanged.
        """
        sampled_episodes_per_iter = (
            self.config.episodes_for_fitness * self.config.pop_size
            + self.config.rollout_episodes * self.config.num_rl_agents
        )

        # Convert the counters to Python ints before doing arithmetic.
        # `step` accumulates `sampled_episodes` with uint32 values, so
        # subtracting it from `total_episodes` as an array would wrap around
        # to a huge number once the counter exceeds the budget (e.g. when
        # resuming a finished run whose last iteration overshot it).
        start_iteration = int(state.metrics.iterations)
        remaining_episodes = self.config.total_episodes - int(
            state.metrics.sampled_episodes
        )
        num_iters = max(0, math.ceil(remaining_episodes / sampled_episodes_per_iter))
        final_iteration = start_iteration + num_iters

        for iters in range(start_iteration + 1, final_iteration + 1):
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            train_metrics_dict = train_metrics.to_local_dict()
            # per-individual population metrics are logged as statistics
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )
            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            if self.config.num_rl_agents > 1:
                # several RL agents: summarize the per-agent arrays
                train_metrics_dict["rl_episode_lengths"] = get_1d_array_statistics(
                    train_metrics_dict["rl_episode_lengths"], histogram=True
                )
                train_metrics_dict["rl_episode_returns"] = get_1d_array_statistics(
                    train_metrics_dict["rl_episode_returns"], histogram=True
                )
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )
            else:
                # single RL agent: drop the leading agent axis
                train_metrics_dict["rl_episode_lengths"] = train_metrics_dict[
                    "rl_episode_lengths"
                ].squeeze(0)
                train_metrics_dict["rl_episode_returns"] = train_metrics_dict[
                    "rl_episode_returns"
                ].squeeze(0)

            self.recorder.write(train_metrics_dict, iters)

            is_final_iteration = iters == final_iteration

            if iters % self.config.eval_interval == 0 or is_final_iteration:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = eval_metrics.to_local_dict()
                if self.config.num_rl_agents > 1:
                    eval_metrics_dict = jtu.tree_map(get_1d_array, eval_metrics_dict)

                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            # always force a checkpoint on the last iteration
            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=is_final_iteration,
            )

        return state