Source code for evorl.algorithms.erl.cemrl_td3.cemrl_td3_workflow

  1import jax
  2import jax.numpy as jnp
  3import jax.tree_util as jtu
  4import optax
  5
  6from omegaconf import DictConfig
  7from evorl.agent import Agent, AgentState
  8
  9from evorl.distributed import agent_gradient_update
 10from evorl.types import PyTreeDict
 11from evorl.utils.jax_utils import (
 12    right_shift_with_padding,
 13    tree_stop_gradient,
 14    scan_and_mean,
 15)
 16from evorl.utils.rl_toolkits import (
 17    flatten_rollout_trajectory,
 18    flatten_pop_rollout_episode,
 19    soft_target_update,
 20)
 21from evorl.algorithms.td3 import TD3TrainMetric, TD3NetworkParams
 22
 23from ..cemrl_workflow import CEMRLWorkflowBase
 24
 25
class CEMRLTD3WorkflowTemplate(CEMRLWorkflowBase):
    """Workflow template that pairs the CEM-RL base workflow with a TD3 agent.

    On construction it builds the multi-actor / shared-critic update function
    via ``build_cemrl_rl_update_fn`` and stores it as ``self._rl_update_fn``.
    ``_rollout`` and ``_rl_update`` implement the population rollout and the
    gradient-based update phase on top of it.
    """

    def __init__(self, **kwargs):
        """Initialise the base workflow with ``kwargs`` and build the update fn."""
        super().__init__(**kwargs)
        self._rl_update_fn = build_cemrl_rl_update_fn(
            self.agent,
            self.optimizer,
            self.config,
            self.agent_state_vmap_axes,
        )

    def _rollout(self, pop_agent_state, replay_buffer_state, key):
        """Roll out the population and push the transitions to the buffer.

        Thin wrapper around ``rollout_episode`` that fills in the collector,
        replay buffer, vmap axes and the episode / population sizes from
        ``self`` and ``self.config``.

        Returns:
            Whatever ``rollout_episode`` returns:
            ``(eval_metrics, trajectory, replay_buffer_state)``.
        """
        cfg = self.config
        return rollout_episode(
            pop_agent_state,
            replay_buffer_state,
            key,
            collector=self.collector,
            replay_buffer=self.replay_buffer,
            agent_state_vmap_axes=self.agent_state_vmap_axes,
            num_episodes=cfg.episodes_for_fitness,
            num_agents=cfg.pop_size,
        )

    def _rl_update(self, agent_state, opt_state, replay_buffer_state, key):
        """Run ``config.num_rl_updates_per_iter`` sampled TD3 update steps.

        Each step draws ``actor_update_interval * num_learning_offspring``
        batches from the replay buffer, arranges them as
        ``(actor_update_interval, num_learning_offspring, B, ...)`` and passes
        them to ``self._rl_update_fn``.

        Args:
            agent_state: Agent state holding the learning actors and critic.
            opt_state: Optimizer state consumed by ``self._rl_update_fn``.
            replay_buffer_state: Buffer state to sample from (read only here).
            key: PRNG key for sampling and learning.

        Returns:
            ``(td3_metrics, agent_state, opt_state)`` where ``td3_metrics`` is
            a ``TD3TrainMetric`` built from the outputs of ``scan_and_mean``.
        """
        cfg = self.config

        def _draw_batch(sample_key):
            return self.replay_buffer.sample(replay_buffer_state, sample_key)

        def _one_update(carry, _unused):
            step_key, cur_agent_state, cur_opt_state = carry
            step_key, rb_key, learn_key = jax.random.split(step_key, 3)

            interval = cfg.actor_update_interval
            num_learners = cfg.num_learning_offspring

            # One independent buffer sample per (critic step, learner) pair.
            flat_batches = jax.vmap(_draw_batch)(
                jax.random.split(rb_key, interval * num_learners)
            )

            # (interval * num_learners, B, ...)
            #   -> (actor_update_interval, num_learning_offspring, B, ...)
            batches = jtu.tree_map(
                lambda leaf: leaf.reshape(
                    (interval, num_learners, *leaf.shape[1:])
                ),
                flat_batches,
            )

            (cur_agent_state, cur_opt_state), train_info = self._rl_update_fn(
                cur_agent_state, cur_opt_state, batches, learn_key
            )

            return (step_key, cur_agent_state, cur_opt_state), train_info

        final_carry, reduced_info = scan_and_mean(
            _one_update,
            (key, agent_state, opt_state),
            (),
            length=cfg.num_rl_updates_per_iter,
        )
        _, agent_state, opt_state = final_carry
        critic_loss, actor_loss, critic_loss_dict, actor_loss_dict = reduced_info

        # The update fn sums the actor loss over the learning offspring, so
        # divide by their count to report a per-actor value.
        td3_metrics = TD3TrainMetric(
            actor_loss=actor_loss / cfg.num_learning_offspring,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        )

        return td3_metrics, agent_state, opt_state
110 111
def cemrl_replace_td3_actor_params(
    agent_state: AgentState, pop_actor_params: TD3NetworkParams
) -> AgentState:
    """Return ``agent_state`` with its actor parameters replaced.

    Both ``actor_params`` and ``target_actor_params`` are set to
    ``pop_actor_params``. No other field is passed to ``replace``, so the
    critic parameters are left as they are.

    Args:
        agent_state: State whose ``params`` should receive the new actors.
        pop_actor_params: Actor parameters to install for both the online and
            the target actor.

    Returns:
        The updated agent state.
    """
    updated_params = agent_state.params.replace(
        actor_params=pop_actor_params,
        target_actor_params=pop_actor_params,
    )
    return agent_state.replace(params=updated_params)
# All-zero scalar TD3 metrics; used as the template that
# ``create_dummy_td3_trainmetric`` expands.
DUMMY_TD3_TRAINMETRIC = TD3TrainMetric(
    critic_loss=jnp.zeros(()),
    actor_loss=jnp.zeros(()),
    raw_loss_dict=PyTreeDict(
        {
            name: jnp.zeros(())
            for name in ("critic_loss", "q_value", "actor_loss")
        }
    ),
)


def create_dummy_td3_trainmetric(num: int) -> TD3TrainMetric:
    """Build an all-zero ``TD3TrainMetric`` with ``num``-row raw losses.

    Only the leaves of ``raw_loss_dict`` get a new leading axis of size
    ``num``; ``critic_loss`` and ``actor_loss`` keep the scalar values from
    ``DUMMY_TD3_TRAINMETRIC``.

    Args:
        num: Size of the leading axis added to each raw-loss leaf.

    Returns:
        A copy of ``DUMMY_TD3_TRAINMETRIC`` with the broadcast raw-loss dict.

    Raises:
        ValueError: If ``num`` is smaller than 1.
    """
    if num < 1:
        raise ValueError(f"num should be positive, got {num}")

    def _add_leading_axis(leaf):
        return jnp.broadcast_to(leaf, (num, *leaf.shape))

    return DUMMY_TD3_TRAINMETRIC.replace(
        raw_loss_dict=jtu.tree_map(
            _add_leading_axis, DUMMY_TD3_TRAINMETRIC.raw_loss_dict
        )
    )
146 147
def rollout_episode(
    agent_state: AgentState,
    replay_buffer_state,
    key,
    *,
    collector,
    replay_buffer,
    agent_state_vmap_axes,
    num_episodes,
    num_agents,
):
    """Roll out every agent, then store the flattened transitions.

    ``collector.rollout`` is vmapped over the agents (one PRNG key each, with
    ``num_episodes`` shared). The resulting trajectory is flattened, masked
    and added to the replay buffer.

    Args:
        agent_state: Batched agent state, mapped per ``agent_state_vmap_axes``.
        replay_buffer_state: Buffer state to append to.
        key: PRNG key, split into ``num_agents`` sub-keys.
        collector: Object providing ``rollout(agent_state, key, num_episodes)``.
        replay_buffer: Object providing ``add(state, trajectory, mask)``.
        agent_state_vmap_axes: ``in_axes`` spec for ``agent_state``.
        num_episodes: Episode count forwarded unchanged to every rollout.
        num_agents: Number of agents, i.e. how many keys to split.

    Returns:
        ``(eval_metrics, trajectory, replay_buffer_state)`` where
        ``trajectory`` is the flattened, gradient-stopped data that was added
        (with ``next_obs`` and ``dones`` set to ``None``).
    """
    per_agent_keys = jax.random.split(key, num_agents)
    pop_rollout = jax.vmap(
        collector.rollout,
        in_axes=(agent_state_vmap_axes, 0, None),
    )
    eval_metrics, trajectory = pop_rollout(
        agent_state, per_agent_keys, num_episodes
    )

    # [n, T, B, ...] -> [T, n*B, ...]
    trajectory = flatten_pop_rollout_episode(trajectory.replace(next_obs=None))

    # Validity mask derived from ``dones`` shifted right by one step
    # (presumably keeps steps up to and including the terminal one -- confirm
    # against ``right_shift_with_padding``).
    mask = jnp.logical_not(right_shift_with_padding(trajectory.dones, 1))

    flattened = flatten_rollout_trajectory(
        (trajectory.replace(dones=None), mask)
    )
    trajectory, mask = tree_stop_gradient(flattened)

    replay_buffer_state = replay_buffer.add(
        replay_buffer_state, trajectory, mask
    )

    return eval_metrics, trajectory, replay_buffer_state
180 181
def build_cemrl_rl_update_fn(
    agent: Agent,
    optimizer: optax.GradientTransformation,
    config: DictConfig,
    agent_state_vmap_axes: AgentState,
):
    """Build the TD3 update step for K actors sharing one critic.

    The returned function takes
    ``(agent_state, opt_state, sample_batches, key)`` where ``sample_batches``
    has a leading axis that is split into "all but the last" slices (used for
    critic-only updates) and the last slice (used for one more critic update
    followed by the actor update and the soft target updates).

    Args:
        agent: Agent providing ``critic_loss`` and ``actor_loss``.
        optimizer: Optimizer handed to both gradient-update builders.
        config: Reads ``num_learning_offspring``, ``actor_update_interval``
            and ``tau``.
        agent_state_vmap_axes: ``in_axes`` spec for the agent state when the
            losses are vmapped over the learning offspring.

    Returns:
        ``_update_fn``, which returns
        ``((agent_state, opt_state),
        (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict))``.
    """
    num_learning_offspring = config.num_learning_offspring
    loss_in_axes = (agent_state_vmap_axes, 0, 0)

    def critic_loss_fn(agent_state, sample_batch, key):
        # Single critic evaluated against every actor.
        # sample_batch: (n, B, ...)
        offspring_keys = jax.random.split(key, num_learning_offspring)
        loss_dict = jax.vmap(agent.critic_loss, in_axes=loss_in_axes)(
            agent_state, sample_batch, offspring_keys
        )

        # Mean over the num_learning_offspring axis.
        return loss_dict.critic_loss.mean(), loss_dict

    def actor_loss_fn(agent_state, sample_batch, key):
        # Per-actor losses, vmapped over the learning offspring.
        offspring_keys = jax.random.split(key, num_learning_offspring)
        loss_dict = jax.vmap(agent.actor_loss, in_axes=loss_in_axes)(
            agent_state, sample_batch, offspring_keys
        )

        # Sum over the num_learning_offspring axis.
        return loss_dict.actor_loss.sum(), loss_dict

    def _attach_critic_params(agent_state, critic_params):
        return agent_state.replace(
            params=agent_state.params.replace(critic_params=critic_params)
        )

    def _detach_critic_params(agent_state):
        return agent_state.params.critic_params

    def _attach_actor_params(agent_state, actor_params):
        return agent_state.replace(
            params=agent_state.params.replace(actor_params=actor_params)
        )

    def _detach_actor_params(agent_state):
        return agent_state.params.actor_params

    critic_update_fn = agent_gradient_update(
        critic_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_attach_critic_params,
        detach_fn=_detach_critic_params,
    )

    actor_update_fn = agent_gradient_update(
        actor_loss_fn,
        optimizer,
        has_aux=True,
        attach_fn=_attach_actor_params,
        detach_fn=_detach_actor_params,
    )

    def _update_fn(agent_state, opt_state, sample_batches, key):
        critic_opt_state = opt_state.critic
        actor_opt_state = opt_state.actor

        key, critic_key, actor_key = jax.random.split(key, num=3)

        # All slices but the last feed critic-only updates; the last slice
        # feeds the joint critic + actor update.
        leading_batches = jtu.tree_map(lambda x: x[:-1], sample_batches)
        final_batch = jtu.tree_map(lambda x: x[-1], sample_batches)

        num_critic_only_steps = config.actor_update_interval - 1
        if num_critic_only_steps > 0:

            def _critic_only_step(carry, sample_batch):
                step_key, state, c_opt_state = carry
                step_key, step_critic_key = jax.random.split(step_key)

                # Loss outputs of the intermediate critic steps are discarded.
                _, state, c_opt_state = critic_update_fn(
                    c_opt_state, state, sample_batch, step_critic_key
                )

                return (step_key, state, c_opt_state), None

            _, scan_key = jax.random.split(key)

            (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                _critic_only_step,
                (scan_key, agent_state, critic_opt_state),
                leading_batches,
                length=num_critic_only_steps,
            )

        critic_out, agent_state, critic_opt_state = critic_update_fn(
            critic_opt_state, agent_state, final_batch, critic_key
        )
        critic_loss, critic_loss_dict = critic_out

        actor_out, agent_state, actor_opt_state = actor_update_fn(
            actor_opt_state, agent_state, final_batch, actor_key
        )
        actor_loss, actor_loss_dict = actor_out

        # Soft target updates are applied to the whole param trees directly;
        # no vmap is needed here.
        params = agent_state.params
        agent_state = agent_state.replace(
            params=params.replace(
                target_actor_params=soft_target_update(
                    params.target_actor_params,
                    params.actor_params,
                    config.tau,
                ),
                target_critic_params=soft_target_update(
                    params.target_critic_params,
                    params.critic_params,
                    config.tau,
                ),
            )
        )

        opt_state = opt_state.replace(
            actor=actor_opt_state, critic=critic_opt_state
        )

        return (
            (agent_state, opt_state),
            (critic_loss, actor_loss, critic_loss_dict, actor_loss_dict),
        )

    return _update_fn