Source code for evorl.algorithms.erl.cemrl_td3.cemrl_origin

  1import time
  2from omegaconf import DictConfig
  3
  4import chex
  5import jax
  6import jax.numpy as jnp
  7import jax.tree_util as jtu
  8
  9from evorl.metrics import MetricBase
 10from evorl.types import PyTreeDict, State
 11from evorl.utils.jax_utils import (
 12    tree_get,
 13    tree_set,
 14    scan_and_last,
 15    is_jitted,
 16)
 17from evorl.algorithms.td3 import TD3TrainMetric
 18
 19from ..cemrl_workflow import CEMRLTrainMetric as CEMRLTrainMetricBase
 20from .cemrl_td3_workflow import cemrl_replace_td3_actor_params
 21from .cemrl import CEMRLWorkflow as _CEMRLWorkflow
 22
 23
class CEMRLTrainMetric(CEMRLTrainMetricBase):
    """Per-iteration training metrics, adding two fields to the base CEMRL metrics."""

    # Number of RL update rounds executed in the iteration (uint32 scalar, 0 by default).
    num_updates_per_iter: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Elapsed time of one iteration, as a difference of time.perf_counter() readings.
    time_cost_per_iter: float = 0.0
27 28
class WorkflowMetric(MetricBase):
    """Workflow-level counters; every field is a uint32 scalar that starts at zero."""

    # Running total of environment timesteps sampled so far.
    sampled_timesteps: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Timesteps sampled in the most recent iteration only.
    sampled_timesteps_per_iter: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Running total of episodes sampled so far.
    sampled_episodes: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Number of iterations completed.
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
34 35
class CEMRLWorkflow(_CEMRLWorkflow):
    """Original CEMRL implementation.

    Uses 1 critic + n actors + 1 replay buffer.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "CEMRL-Origin"

    def _setup_workflow_metrics(self) -> MetricBase:
        """Create a fresh set of workflow counters."""
        return WorkflowMetric()

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the workflow and attach its sample-and-update scan body.

        The attached ``_rl_sample_and_update_fn`` draws
        ``actor_update_interval * num_learning_offspring`` batches from the
        replay buffer, reshapes them and forwards them to
        ``workflow._rl_update_fn``. It is wrapped in ``jax.jit`` only when
        ``cls.evaluate`` is already jitted.
        """
        workflow = super()._build_from_config(config)

        def _rl_sample_and_update_fn(scan_carry, _):
            rng, agent_state, opt_state, replay_buffer_state = scan_carry

            interval = config.actor_update_interval
            n_offspring = config.num_learning_offspring

            rng, rb_key, learn_key = jax.random.split(rng, 3)
            rb_keys = jax.random.split(rb_key, interval * n_offspring)

            # One replay-buffer sample per key.
            flat_batches = jax.vmap(
                lambda k: workflow.replay_buffer.sample(replay_buffer_state, k)
            )(rb_keys)

            def _split_leading_axis(leaf):
                # -> (actor_update_interval, num_learning_offspring, B, ...)
                return leaf.reshape((interval, n_offspring, *leaf.shape[1:]))

            sample_batches = jtu.tree_map(_split_leading_axis, flat_batches)

            (agent_state, opt_state), train_info = workflow._rl_update_fn(
                agent_state, opt_state, sample_batches, learn_key
            )

            next_carry = (rng, agent_state, opt_state, replay_buffer_state)
            return next_carry, train_info

        # Follow the jit status of `evaluate` so both run in the same mode.
        workflow._rl_sample_and_update_fn = (
            jax.jit(_rl_sample_and_update_fn)
            if is_jitted(cls.evaluate)
            else _rl_sample_and_update_fn
        )

        return workflow

    def _ec_sample(self, ec_opt_state):
        """Return the result of ``ec_optimizer.ask`` for the given state.

        ``step`` unpacks the result as ``(pop_actor_params, ec_opt_state)``.
        """
        return self.ec_optimizer.ask(ec_opt_state)

    def _rl_update(
        self,
        agent_state,
        opt_state,
        replay_buffer_state,
        key,
        num_updates,
    ):
        """Run ``num_updates`` sample-and-update rounds and collect TD3 metrics.

        ``num_updates`` is passed as the scan length, so this method adds
        num_updates support and therefore cannot be jitted.

        Returns:
            ``(td3_metrics, agent_state, opt_state)``.
        """
        init_carry = (key, agent_state, opt_state, replay_buffer_state)
        final_carry, train_info = scan_and_last(
            self._rl_sample_and_update_fn,
            init_carry,
            (),
            length=num_updates,
        )
        # Only the agent and optimizer states are needed from the final carry.
        _, agent_state, opt_state, _ = final_carry

        critic_loss, actor_loss, critic_loss_dict, actor_loss_dict = train_info

        # smoothed td3 metrics
        merged_loss_dict = PyTreeDict({**critic_loss_dict, **actor_loss_dict})
        td3_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=merged_loss_dict,
        )

        # Scale the actor loss by the number of learning actors.
        scaled_actor_loss = (
            td3_metrics.actor_loss / self.config.num_learning_offspring
        )
        td3_metrics = td3_metrics.replace(actor_loss=scaled_actor_loss)

        return td3_metrics, agent_state, opt_state

    def _rollout_and_update(
        self, pop_agent_state, replay_buffer_state, ec_opt_state, key
    ):
        """Calculate the fitness and update the replay buffer and ec_optimizer.

        Returns:
            ``(eval_metrics, ec_metrics, fitnesses, replay_buffer_state,
            ec_opt_state)``.
        """
        # the trajectory [T, #pop*B, ...]
        # metrics: [#pop, B]
        eval_metrics, _trajectory, replay_buffer_state = self._rollout(
            pop_agent_state, replay_buffer_state, key
        )

        # One fitness value per population member: mean over the last axis.
        fitnesses = eval_metrics.episode_returns.mean(axis=-1)
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        return (
            eval_metrics,
            ec_metrics,
            fitnesses,
            replay_buffer_state,
            ec_opt_state,
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one iteration: CEM sample, optional RL update, then CEM update.

        The RL update is skipped while the iteration count is within
        ``config.warmup_iters``. In that case ``rl_metrics`` is ``None`` and
        ``num_updates_per_iter`` is zero.

        Returns:
            ``(train_metrics, state)`` with the updated workflow state.
        """
        start_t = time.perf_counter()
        cfg = self.config
        pop_size = cfg.pop_size
        prev_metrics = state.metrics

        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = prev_metrics.iterations + 1

        key, perm_key, rollout_key, learn_key = jax.random.split(state.key, num=4)

        # ======= CEM Sample ========
        pop_actor_params, ec_opt_state = self._ec_sample(ec_opt_state)

        # ======== RL update ========
        if iterations > cfg.warmup_iters:
            # Pick distinct population members to receive RL updates.
            learning_actor_indices = jax.random.choice(
                perm_key,
                pop_size,
                (cfg.num_learning_offspring,),
                replace=False,
            )
            learning_actor_params = tree_get(pop_actor_params, learning_actor_indices)
            learning_agent_state = cemrl_replace_td3_actor_params(
                agent_state, learning_actor_params
            )
            # reset and add actors' opt_state
            fresh_actor_opt_state = self.optimizer.init(learning_actor_params)
            learning_opt_state = opt_state.replace(actor=fresh_actor_opt_state)

            # Update budget derived from the previous iteration's timesteps.
            update_budget = jnp.ceil(
                prev_metrics.sampled_timesteps_per_iter * cfg.rl_updates_frac
            ).astype(jnp.uint32)
            num_updates = update_budget // cfg.actor_update_interval

            td3_metrics, learning_agent_state, learning_opt_state = self._rl_update(
                learning_agent_state,
                learning_opt_state,
                replay_buffer_state,
                learn_key,
                num_updates,
            )

            # Write the trained actors back into the population.
            pop_actor_params = tree_set(
                pop_actor_params,
                learning_agent_state.params.actor_params,
                learning_actor_indices,
                unique_indices=True,
            )

            # drop the actors and their opt_state
            agent_state = cemrl_replace_td3_actor_params(
                learning_agent_state, pop_actor_params=None
            )
            opt_state = learning_opt_state.replace(actor=None)

            # rl injection
            ec_opt_state = self._rl_injection(ec_opt_state, pop_actor_params)
        else:
            num_updates = jnp.zeros((), dtype=jnp.uint32)
            td3_metrics = None

        # ======== CEM update ========
        pop_agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params)
        (
            eval_metrics,
            ec_metrics,
            fitnesses,
            replay_buffer_state,
            ec_opt_state,
        ) = self._rollout_and_update(
            pop_agent_state,
            replay_buffer_state,
            ec_opt_state,
            rollout_key,
        )

        # adding debug info for CEM
        ec_info = PyTreeDict(ec_metrics)
        ec_info.cov_eps = ec_opt_state.cov_eps
        if td3_metrics is not None:
            # How many RL-trained actors ended up among the top-k by fitness.
            _, elites_indices = jax.lax.top_k(fitnesses, cfg.num_elites)
            elites_from_rl = jnp.isin(learning_actor_indices, elites_indices)
            ec_info.elites_from_rl = elites_from_rl.sum()
            ec_info.elites_from_rl_ratio = elites_from_rl.mean()

        train_metrics = CEMRLTrainMetric(
            rb_size=replay_buffer_state.buffer_size,
            pop_episode_lengths=eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
            ec_info=ec_info,
            num_updates_per_iter=num_updates,
            time_cost_per_iter=time.perf_counter() - start_t,
        )

        # calculate the number of timestep
        sampled_timesteps = eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(cfg.episodes_for_fitness * pop_size)

        workflow_metrics = prev_metrics.replace(
            sampled_timesteps=prev_metrics.sampled_timesteps + sampled_timesteps,
            sampled_timesteps_per_iter=sampled_timesteps,
            sampled_episodes=prev_metrics.sampled_episodes + sampled_episodes,
            iterations=iterations,
        )

        new_state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, new_state

    @classmethod
    def enable_jit(cls) -> None:
        """Wrap the jit-compatible methods in ``jax.jit`` with ``self`` static.

        ``step`` and ``_rl_update`` are not wrapped here.
        """
        jit_targets = (
            "_ec_sample",
            "_rollout_and_update",
            "evaluate",
            "_postsetup_replaybuffer",
        )
        for method_name in jit_targets:
            jitted = jax.jit(getattr(cls, method_name), static_argnums=(0,))
            setattr(cls, method_name, jitted)