Source code for evorl.algorithms.erl.cemrl_td3.cemrl

  1import math
  2import numpy as np
  3from typing_extensions import Self  # pytype: disable=not-supported-yet]
  4from omegaconf import DictConfig
  5
  6import chex
  7import jax
  8import jax.numpy as jnp
  9import jax.tree_util as jtu
 10import optax
 11
 12from evorl.replay_buffers import ReplayBuffer
 13from evorl.metrics import MetricBase
 14from evorl.types import PyTreeDict, State, Params
 15from evorl.utils.jax_utils import tree_get, tree_set
 16from evorl.evaluators import Evaluator, EpisodeCollector
 17from evorl.agent import AgentState
 18from evorl.envs import create_env, AutoresetMode
 19from evorl.recorders import get_1d_array_statistics, add_prefix
 20from evorl.ec.optimizers import SepCEM, ECState, ExponentialScheduleSpec
 21from evorl.algorithms.td3 import make_mlp_td3_agent, TD3NetworkParams
 22from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
 23
 24from ..cemrl_workflow import CEMRLTrainMetric
 25from .cemrl_td3_workflow import (
 26    CEMRLTD3WorkflowTemplate,
 27    cemrl_replace_td3_actor_params,
 28    create_dummy_td3_trainmetric,
 29)
 30
 31
class EvaluateMetric(MetricBase):
    """Evaluation result of the actor built from the CEM population mean.

    Both fields are filled by ``CEMRLWorkflow.evaluate`` with the average
    over all evaluation episodes.
    """

    # Mean episode return of the population-center actor.
    pop_center_episode_returns: chex.Array
    # Mean episode length of the population-center actor.
    pop_center_episode_lengths: chex.Array
35 36
class CEMRLWorkflow(CEMRLTD3WorkflowTemplate):
    """1 critic + n actors + 1 replay buffer.

    We use shard_map to split and parallelize the population.

    Every call to :meth:`step` performs one CEM-RL iteration:

    1. sample a population of actor parameters from the CEM optimizer;
    2. once the warmup iterations are over, train a random subset of the
       population with TD3 against the shared critic;
    3. roll out the whole population to obtain the fitnesses (and new
       replay-buffer data);
    4. update the CEM distribution with the (partially RL-trained)
       population.
    """

    @classmethod
    def name(cls) -> str:
        """Return the identifier of this workflow."""
        return "CEMRL"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Create the workflow and all of its components from ``config``.

        Args:
            config: Experiment configuration. The fields read here are the
                ones accessed in the body (env, agent_network, optimizer,
                CEM settings, replay-buffer and evaluation settings).

        Returns:
            A fully wired workflow instance.

        Raises:
            AssertionError: If neither ``warmup_iters`` nor
                ``random_timesteps`` is positive.
        """
        assert config.warmup_iters > 0 or config.random_timesteps > 0, (
            "Either warmup_iters or random_timesteps should be positive to pre-fill some data in the replay buffer"
        )

        # env for one actor
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is optional: it is only chained in front of Adam
        # when a positive clip norm is configured.
        grad_clip_norm = config.optimizer.grad_clip_norm
        if grad_clip_norm is not None and grad_clip_norm > 0:
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        ec_optimizer = SepCEM(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            cov_eps_schedule=ExponentialScheduleSpec(**config.cov_eps),
            weighted_update=config.weighted_update,
            rank_weight_shift=config.rank_weight_shift,
            mirror_sampling=config.mirror_sampling,
        )

        # The fitness rollouts either use the exploratory action function or
        # the evaluation one, depending on the config switch.
        action_fn = (
            agent.compute_actions
            if config.fitness_with_exploration
            else agent.evaluate_actions
        )

        collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # vmap axes of the agent state: only the actor-related params carry a
        # leading axis (0); critic params and the obs preprocessor state are
        # marked with None.
        agent_state_vmap_axes = AgentState(
            params=TD3NetworkParams(
                critic_params=None,
                actor_params=0,
                target_critic_params=None,
                target_actor_params=0,
            ),
            obs_preprocessor_state=None,
        )

        return cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            collector=collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Initialize the agent state, the RL optimizer state and the CEM state.

        Args:
            key: PRNG key; split into one key for the agent and one for the
                CEM optimizer.

        Returns:
            A tuple ``(agent_state, opt_state, ec_opt_state)``. The returned
            ``agent_state`` has its actor params replaced via
            ``cemrl_replace_td3_actor_params(..., pop_actor_params=None)``,
            and ``opt_state.actor`` is ``None``.
        """
        agent_key, ec_key = jax.random.split(key)

        # one actor + one critic
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        # The freshly initialized actor seeds the CEM optimizer.
        init_actor_params = agent_state.params.actor_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # The actors are re-sampled from CEM at every step, so none is kept
        # in the persistent agent state.
        agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params=None)

        opt_state = PyTreeDict(
            # Note: we create and drop the actors' opt_state at every step
            critic=self.optimizer.init(agent_state.params.critic_params),
            actor=None,
        )

        return agent_state, opt_state, ec_opt_state

    def _rl_injection(self, ec_opt_state: ECState, pop: Params) -> ECState:
        """Return ``ec_opt_state`` with its population replaced by ``pop``."""
        return ec_opt_state.replace(pop=pop)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one CEM-RL iteration.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, state)`` with the metrics of this
            iteration and the updated workflow state.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        # RL updates (and the related statistics) are only active once the
        # warmup iterations are over.
        rl_enabled = iterations > self.config.warmup_iters

        key, perm_key, rollout_key, learn_key = jax.random.split(state.key, num=4)

        # ======= CEM Sample ========
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)

        # ======== RL update ========
        # Randomly pick (without replacement) the actors trained by TD3.
        learning_actor_indices = jax.random.choice(
            perm_key,
            pop_size,
            (self.config.num_learning_offspring,),
            replace=False,
        )

        def _rl_update(agent_state, opt_state, pop_actor_params):
            learning_actor_params = tree_get(pop_actor_params, learning_actor_indices)
            learning_agent_state = cemrl_replace_td3_actor_params(
                agent_state, learning_actor_params
            )

            # reset actors' opt_state
            learning_opt_state = opt_state.replace(
                actor=self.optimizer.init(learning_actor_params),
            )

            # Uses the replay buffer as it was before this step's rollout.
            td3_metrics, learning_agent_state, learning_opt_state = self._rl_update(
                learning_agent_state,
                learning_opt_state,
                replay_buffer_state,
                learn_key,
            )

            # Write the trained actors back into their population slots.
            pop_actor_params = tree_set(
                pop_actor_params,
                learning_agent_state.params.actor_params,
                learning_actor_indices,
                unique_indices=True,
            )
            # drop the actors and their opt_state
            agent_state = cemrl_replace_td3_actor_params(
                learning_agent_state, pop_actor_params=None
            )
            opt_state = learning_opt_state.replace(actor=None)
            return td3_metrics, pop_actor_params, agent_state, opt_state

        def _dummy_rl_update(agent_state, opt_state, pop_actor_params):
            # Both branches of lax.cond must return the same structure, hence
            # the placeholder metrics.
            return (
                create_dummy_td3_trainmetric(self.config.num_learning_offspring),
                pop_actor_params,
                agent_state,
                opt_state,
            )

        td3_metrics, pop_actor_params, agent_state, opt_state = jax.lax.cond(
            rl_enabled,
            _rl_update,
            _dummy_rl_update,
            agent_state,
            opt_state,
            pop_actor_params,
        )

        # ======== CEM update ========
        pop_agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params)

        # the trajectory [T, #pop*B, ...] (not needed here)
        # metrics: [#pop, B]
        eval_metrics, _, replay_buffer_state = self._rollout(
            pop_agent_state, replay_buffer_state, rollout_key
        )

        # One fitness value per individual: the mean return over its episodes.
        fitnesses = eval_metrics.episode_returns.mean(axis=-1)

        # Let CEM see the RL-updated population before it is told the fitnesses.
        ec_opt_state = self._rl_injection(ec_opt_state, pop_actor_params)
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        # adding debug info for CEM
        ec_info = PyTreeDict(ec_metrics)
        ec_info.cov_eps = ec_opt_state.cov_eps

        def _calc_elites_stats(ec_info):
            # How many of the RL-trained actors ended up among the elites.
            elites_indices = jax.lax.top_k(fitnesses, self.config.num_elites)[1]
            elites_from_rl = jnp.isin(learning_actor_indices, elites_indices).astype(
                jnp.int32
            )
            ec_info.elites_from_rl = elites_from_rl.sum()
            ec_info.elites_from_rl_ratio = elites_from_rl.mean()
            return ec_info

        def _dummy_calc_elites_stats(ec_info):
            ec_info.elites_from_rl = jnp.zeros((), dtype=jnp.int32)
            ec_info.elites_from_rl_ratio = jnp.zeros(())
            return ec_info

        ec_info = jax.lax.cond(
            rl_enabled,
            _calc_elites_stats,
            _dummy_calc_elites_stats,
            ec_info,
        )

        train_metrics = CEMRLTrainMetric(
            rb_size=replay_buffer_state.buffer_size,
            pop_episode_lengths=eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
            ec_info=ec_info,
        )

        # calculate the number of timestep
        sampled_timesteps = eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the actor given by the mean of the CEM distribution.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(eval_metrics, state)``: the mean episode return and
            length of the population-center actor, and the state with an
            advanced PRNG key.
        """
        pop_mean_actor_params = state.ec_opt_state.mean

        pop_mean_agent_state = cemrl_replace_td3_actor_params(
            state.agent_state, pop_mean_actor_params
        )

        key, eval_key = jax.random.split(state.key, num=2)

        # [#episodes]
        raw_eval_metrics = self.evaluator.evaluate(
            pop_mean_agent_state, eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            pop_center_episode_returns=raw_eval_metrics.episode_returns.mean(),
            pop_center_episode_lengths=raw_eval_metrics.episode_lengths.mean(),
        )

        state = state.replace(key=key)

        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Train until ``config.total_episodes`` episodes have been sampled.

        Every iteration calls :meth:`step`, records the workflow and training
        metrics, and every ``config.eval_interval`` iterations (as well as on
        the last one) evaluates the population center. A checkpoint is saved
        after every iteration.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The workflow state after the last iteration. If the episode
            budget is already exhausted, ``state`` is returned unchanged.
        """
        episodes_per_iter = self.config.episodes_for_fitness * self.config.pop_size

        # Do the subtraction on Python numbers. `step` accumulates the counter
        # with a jnp.uint32 addend, so if the counter is unsigned,
        # `total_episodes - sampled_episodes` in JAX arithmetic wraps around
        # to a huge positive value once the counter exceeds the budget. That
        # normally happens after a finished run, because the iteration count
        # is rounded up. The clamp makes an exhausted budget mean zero
        # iterations.
        remaining_episodes = self.config.total_episodes - int(
            state.metrics.sampled_episodes
        )
        num_iters = max(0, math.ceil(remaining_episodes / episodes_per_iter))

        final_iteration = num_iters + state.metrics.iterations
        for i in range(state.metrics.iterations, final_iteration):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            # Per-individual arrays are summarized via get_1d_array_statistics
            # (with histogram=True) before being recorded.
            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            if train_metrics_dict["rl_metrics"] is not None:
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )

            self.recorder.write(train_metrics_dict, iters)

            # Track the spread of the CEM search distribution.
            std_statistics = get_std_statistics(state.ec_opt_state.variance["params"])
            self.recorder.write({"ec/std": std_statistics}, iters)

            if iters % self.config.eval_interval == 0 or iters == final_iteration:
                eval_metrics, state = self.evaluate(state)

                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            # Optionally leave the replay buffer out of the checkpoint.
            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            self.checkpoint_manager.save(
                iters,
                saved_state,
                iters == final_iteration,
            )

        return state
394 395
def get_std_statistics(variance):
    """Summarize the standard deviation of every leaf of a variance pytree.

    Args:
        variance: Pytree whose leaves are arrays of variances.

    Returns:
        A pytree with the same structure in which every leaf is replaced by
        a dict holding the ``min``, ``max`` and ``mean`` of the element-wise
        square root of that leaf, converted with ``.tolist()``.
    """

    def _summarize(leaf_variance):
        std = np.sqrt(leaf_variance)
        return {
            "min": np.min(std).tolist(),
            "max": np.max(std).tolist(),
            "mean": np.mean(std).tolist(),
        }

    return jtu.tree_map(_summarize, variance)