Source code for evorl.algorithms.erl.erl_td3.erl_origin

  1import logging
  2import time
  3from omegaconf import DictConfig
  4
  5import chex
  6import jax
  7import jax.numpy as jnp
  8import jax.tree_util as jtu
  9
 10from evorl.metrics import MetricBase
 11from evorl.types import PyTreeDict, State
 12from evorl.utils.jax_utils import is_jitted
 13from evorl.algorithms.td3 import TD3TrainMetric
 14
 15from ..erl_workflow import ERLTrainMetric as ERLTrainMetricBase
 16from .erl_td3_workflow import create_dummy_td3_trainmetric, erl_replace_td3_actor_params
 17from .erl_ga import ERLGAWorkflow
 18
 19
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
 21
 22
class ERLTrainMetric(ERLTrainMetricBase):
    """Per-iteration training metrics of the ERL-Origin workflow.

    Adds two fields on top of the base ERL train metric.
    """

    # Number of RL update steps executed in the iteration (uint32 scalar).
    num_updates_per_iter: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
    # Elapsed wall-clock time of the iteration.
    time_cost_per_iter: float = 0.0
26 27
class ERLWorkflow(ERLGAWorkflow):
    """Original ERL implementation built on the ERL-GA workflow.

    The number of RL updates per iteration is derived from the number of
    sampled timesteps (see ``rl_updates_mode`` in `step()`), so it varies from
    one iteration to the next. Therefore `step()` cannot be jitted as a whole;
    `enable_jit()` only jits its individual building blocks.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "ERL-Origin"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the parent workflow and attach the scan body for RL updates.

        Args:
            config: Workflow config; ``actor_update_interval`` and
                ``num_rl_agents`` are read by the attached scan body.

        Returns:
            The workflow created by the parent class, with
            ``_rl_sample_and_update_fn`` set on the instance.
        """
        workflow = super()._build_from_config(config)

        def _rl_sample_and_update_fn(carry, unused_t):
            # The metrics slot of the incoming carry is ignored; it is
            # overwritten with the metrics of this update.
            rng, agent_state, opt_state, replay_buffer_state, _ = carry

            rng, sample_rng, update_rng = jax.random.split(rng, 3)
            batch_rngs = jax.random.split(
                sample_rng, config.actor_update_interval * config.num_rl_agents
            )

            # The buffer state is closed over, so only the keys are vmapped.
            flat_batches = jax.vmap(
                lambda k: workflow.replay_buffer.sample(replay_buffer_state, k)
            )(batch_rngs)

            def _split_leading_axis(leaf):
                # (interval * agents, ...) -> (interval, agents, ...)
                grouped_shape = (
                    config.actor_update_interval,
                    config.num_rl_agents,
                    *leaf.shape[1:],
                )
                return leaf.reshape(grouped_shape)

            sample_batches = jtu.tree_map(_split_leading_axis, flat_batches)

            updated_states, losses = workflow._rl_update_fn(
                agent_state, opt_state, sample_batches, update_rng
            )
            agent_state, opt_state = updated_states
            critic_loss, actor_loss, critic_loss_dict, actor_loss_dict = losses

            step_metrics = TD3TrainMetric(
                critic_loss=critic_loss,
                actor_loss=actor_loss,
                raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
            )

            next_carry = (
                rng,
                agent_state,
                opt_state,
                replay_buffer_state,
                step_metrics,
            )
            # To save memory, the metrics travel in the carry instead of the
            # per-step output, so only the latest ones are kept.
            return next_carry, None

        # Follow the jit setting of the class: jit the scan body only when
        # `evaluate` has been jitted as well.
        jit_enabled = is_jitted(cls.evaluate)
        workflow._rl_sample_and_update_fn = (
            jax.jit(_rl_sample_and_update_fn)
            if jit_enabled
            else _rl_sample_and_update_fn
        )

        return workflow

    def _ec_update(self, ec_opt_state, fitnesses):
        """Report the population fitnesses to the EC optimizer via ``tell``."""
        return self.ec_optimizer.tell(ec_opt_state, fitnesses)

    def _ec_update_with_rl_injection(self, ec_opt_state, agent_state, fitnesses):
        """Apply `_rl_injection`, then report fitnesses via ``tell_external``."""
        injected_ec_opt_state = self._rl_injection(ec_opt_state, agent_state)
        return self.ec_optimizer.tell_external(injected_ec_opt_state, fitnesses)

    def _rl_update(self, agent_state, opt_state, replay_buffer_state, key, num_updates):
        """Run ``num_updates`` sample-and-update steps with `jax.lax.scan`.

        Unlike ERL-GA, ``num_updates`` is large here, so only the train
        metrics of the last update are kept.

        Returns:
            A ``(td3_metrics, agent_state, opt_state)`` tuple.
        """
        scan_init = (
            key,
            agent_state,
            opt_state,
            replay_buffer_state,
            # Placeholder metrics that give the carry its metrics slot.
            create_dummy_td3_trainmetric(self.config.num_rl_agents),
        )

        final_carry, _ = jax.lax.scan(
            self._rl_sample_and_update_fn,
            scan_init,
            (),
            length=num_updates,
        )
        _, agent_state, opt_state, _, td3_metrics = final_carry

        return td3_metrics, agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one ERL iteration: EC rollout, RL rollout and update, EC update.

        Args:
            state: Workflow state; its PRNG key, agent / optimizer / EC
                optimizer states, replay buffer state and running metrics are
                read.

        Returns:
            A ``(train_metrics, state)`` pair: the `ERLTrainMetric` of this
            iteration and the state with every updated component replaced.

        Raises:
            ValueError: If ``config.rl_updates_mode`` is neither ``"global"``
                nor ``"iter"``.
        """
        tic = time.perf_counter()
        cfg = self.config
        uint_zero = jnp.zeros((), dtype=jnp.uint32)

        iterations = state.metrics.iterations + 1
        next_key, ec_rollout_key, rl_rollout_key, learn_key = jax.random.split(
            state.key, num=4
        )

        # ======== EC rollout ========
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(state.ec_opt_state)
        pop_agent_state = erl_replace_td3_actor_params(
            state.agent_state, pop_actor_params
        )
        ec_eval_metrics, _ec_trajectory, replay_buffer_state = self._ec_rollout(
            pop_agent_state, state.replay_buffer_state, ec_rollout_key
        )

        # Counters of this iteration; at this point they cover the EC part only.
        ec_timesteps = ec_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_timesteps = uint_zero + ec_timesteps
        sampled_episodes = uint_zero + jnp.uint32(
            cfg.episodes_for_fitness * cfg.pop_size
        )

        # Mean over the last axis; the returns also serve as the EC fitnesses.
        pop_mean_returns = ec_eval_metrics.episode_returns.mean(-1)
        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=pop_mean_returns,
        )

        # ======== RL update ========
        rl_eval_metrics, _rl_trajectory, replay_buffer_state = self._rl_rollout(
            state.agent_state, replay_buffer_state, rl_rollout_key
        )

        # The update budget uses the counters *before* the RL rollout timesteps
        # of this iteration are added.
        if cfg.rl_updates_mode == "global":  # same as original ERL
            budget_timesteps = state.metrics.sampled_timesteps + sampled_timesteps
        elif cfg.rl_updates_mode == "iter":
            budget_timesteps = sampled_timesteps
        else:
            raise ValueError(f"Unknown rl_updates_mode: {self.config.rl_updates_mode}")

        scaled_budget = jnp.ceil(budget_timesteps * cfg.rl_updates_frac)
        num_updates = scaled_budget.astype(jnp.uint32) // cfg.actor_update_interval

        td3_metrics, agent_state, opt_state = self._rl_update(
            state.agent_state, state.opt_state, replay_buffer_state, learn_key, num_updates
        )

        # Turn the losses into averages over the RL agents.
        num_rl_agents = cfg.num_rl_agents
        td3_metrics = td3_metrics.replace(
            actor_loss=td3_metrics.actor_loss / num_rl_agents,
            critic_loss=td3_metrics.critic_loss / num_rl_agents,
        )

        train_metrics = train_metrics.replace(
            num_updates_per_iter=num_updates,
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
        )

        rl_sampled_timesteps = rl_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_timesteps = sampled_timesteps + rl_sampled_timesteps
        sampled_episodes = sampled_episodes + jnp.uint32(
            num_rl_agents * cfg.rollout_episodes
        )

        # ======== EC update ========
        fitnesses = pop_mean_returns
        inject_rl_agent = iterations % cfg.rl_injection_interval == 0

        if inject_rl_agent:
            ec_metrics, ec_opt_state = self._ec_update_with_rl_injection(
                ec_opt_state, agent_state, fitnesses
            )
        else:
            ec_metrics, ec_opt_state = self._ec_update(ec_opt_state, fitnesses)

        train_metrics = train_metrics.replace(
            ec_info=ec_metrics,
            rb_size=replay_buffer_state.buffer_size,
            time_cost_per_iter=time.perf_counter() - tic,
        )

        # iterations is the number of updates of the agent
        prev_metrics = state.metrics
        workflow_metrics = prev_metrics.replace(
            sampled_timesteps=prev_metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=prev_metrics.sampled_episodes + sampled_episodes,
            rl_sampled_timesteps=prev_metrics.rl_sampled_timesteps
            + rl_sampled_timesteps,
            iterations=iterations,
        )

        new_state = state.replace(
            key=next_key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, new_state

    @classmethod
    def enable_jit(cls) -> None:
        """Jit the building blocks of `step()` in place on the class.

        `step()` and `_rl_update()` are not in the list and stay un-jitted.
        """
        # Do not jit replay buffer add
        jit_targets = (
            "_rl_rollout",
            "_ec_rollout",
            "_ec_update",
            "_ec_update_with_rl_injection",
            "evaluate",
            "_postsetup_replaybuffer",
        )
        for method_name in jit_targets:
            # Argument 0 is `self`, which must be treated as static.
            jitted = jax.jit(getattr(cls, method_name), static_argnums=(0,))
            setattr(cls, method_name, jitted)
248 249
def get_ec_pop_statistics(pop):
    """Summarize every parameter leaf of a population by its value range.

    Args:
        pop: Mapping with a ``"params"`` entry holding a pytree of arrays.

    Returns:
        A pytree with the structure of ``pop["params"]`` in which each leaf is
        replaced by ``{"min": ..., "max": ...}``, both converted to plain
        Python values with ``tolist()``.
    """
    return jtu.tree_map(
        lambda leaf: {
            "min": jnp.min(leaf).tolist(),
            "max": jnp.max(leaf).tolist(),
        },
        pop["params"],
    )