Source code for evorl.algorithms.meta.pbt_sac.param_sac

  1import logging
  2from omegaconf import DictConfig
  3
  4import chex
  5import jax
  6import jax.numpy as jnp
  7import optax
  8
  9from evorl.agent import AgentState
 10from evorl.envs import Space, Box, create_env, AutoresetMode
 11from evorl.evaluators import Evaluator
 12from evorl.distributed import agent_gradient_update, psum
 13from evorl.distribution import get_tanh_norm_dist
 14from evorl.metrics import MetricBase
 15from evorl.rollout import rollout
 16from evorl.replay_buffers import ReplayBuffer
 17from evorl.types import PyTreeDict, State, LossDict
 18from evorl.sample_batch import SampleBatch
 19from evorl.networks import make_policy_network, make_q_network
 20from evorl.utils import running_statistics
 21from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient
 22from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
 23
 24from evorl.algorithms.offpolicy_utils import clean_trajectory, OffPolicyWorkflowTemplate
 25from evorl.algorithms.sac import SACTrainMetric, SACAgent
 26
 27
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
 29
 30
class ParamSACTrainMetric(SACTrainMetric):
    """SAC training metrics extended with the rollout trajectory of the step.

    `ParamSACWorkflow.step` fills `trajectory` with the cleaned, flattened and
    gradient-stopped rollout it collected. NOTE(review): presumably this is
    consumed externally (e.g. by the PBT workflow) to fill the replay buffer,
    since ParamSACWorkflow treats the replay buffer as read-only — confirm.
    """

    # Rollout trajectory collected in the step; None when not provided.
    trajectory: SampleBatch = None
class ParamSACAgent(SACAgent):
    """SAC agent with parameterized hyperparameters.

    The discount factor is stored in `extra_state.discount_g` under the
    reparameterization ``discount = 1 - exp(-discount_g)``, and the critic
    loss reads it from the agent state instead of from `self.discount`.
    """

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Initialize the base SAC state and add `discount_g` to `extra_state`.

        `discount_g` is the inverse of ``discount = 1 - exp(-g)``, i.e.
        ``g = -log(1 - discount)``, computed from `self.discount` in float32.
        """
        base_state = super().init(obs_space, action_space, key)

        g = -jnp.log(1 - jnp.float32(self.discount))
        extended_extra = base_state.extra_state.replace(discount_g=g)
        return base_state.replace(extra_state=extended_extra)

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Soft TD loss for the twin critics, using the state-held discount.

        Args:
            agent_state: current agent state (params, preprocessor state,
                and `extra_state.discount_g`).
            sample_batch: transitions; next observations are read from
                `extras.env_extras.ori_obs` and terminal flags from
                `extras.env_extras.termination`.
            key: PRNG key used to sample next actions from the policy.

        Returns:
            PyTreeDict with `critic_loss` (squared error summed over the two
            critics, averaged over the batch) and `q_value` (mean Q).
        """
        params = agent_state.params
        env_extras = sample_batch.extras.env_extras

        obs, next_obs = sample_batch.obs, env_extras.ori_obs
        if self.normalize_obs:
            pre_state = agent_state.obs_preprocessor_state
            obs = self.obs_preprocessor(obs, pre_state)
            next_obs = self.obs_preprocessor(next_obs, pre_state)

        # discount = 1 - exp(-g); zeroed on terminal transitions.
        gamma = 1 - jnp.exp(-agent_state.extra_state.discount_g)
        discounts = gamma * (1 - env_extras.termination)

        alpha = jnp.exp(params.log_alpha)

        # [B, 2]
        qs = self.critic_network.apply(params.critic_params, obs, sample_batch.actions)

        # The actor output is split in half to parameterize the tanh-normal dist.
        dist_inputs = self.actor_network.apply(params.actor_params, next_obs)
        next_dist = get_tanh_norm_dist(*jnp.split(dist_inputs, 2, axis=-1))
        next_actions = next_dist.sample(seed=key)
        next_logp = next_dist.log_prob(next_actions)

        # [B, 2]
        next_qs = self.critic_network.apply(
            params.target_critic_params, next_obs, next_actions
        )
        soft_next_value = jnp.min(next_qs, axis=-1) - alpha * next_logp
        td_target = sample_batch.rewards + discounts * soft_next_value
        # Same target for both critics.
        td_target = jnp.broadcast_to(td_target[..., None], (*td_target.shape, 2))

        loss = optax.squared_error(qs, td_target).sum(-1).mean()
        return PyTreeDict(critic_loss=loss, q_value=qs.mean())
def make_mlp_sac_agent(
    action_space: Space,
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    init_alpha: float = 1.0,
    discount: float = 0.99,
    normalize_obs: bool = False,
):
    """Build a ParamSACAgent with MLP actor and twin-Q MLP critic networks.

    Args:
        action_space: environment action space; only `Box` is supported.
        critic_hidden_layer_sizes: hidden layer widths of the critic MLP.
        actor_hidden_layer_sizes: hidden layer widths of the actor MLP.
        init_alpha: initial entropy temperature passed to the agent.
        discount: initial discount factor passed to the agent.
        normalize_obs: when True, `running_statistics.normalize` is used as
            the observation preprocessor. NOTE(review): the flag itself is not
            forwarded to ParamSACAgent; presumably the agent derives
            normalization from `obs_preprocessor` being non-None — confirm.

    Returns:
        A configured ParamSACAgent.

    Raises:
        NotImplementedError: if `action_space` is not a `Box`.
    """
    if not isinstance(action_space, Box):
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    # The actor emits two values per action dimension (mean + std parameters).
    action_size = action_space.shape[0] * 2

    actor_network = make_policy_network(
        action_size=action_size,  # mean+std
        hidden_layer_sizes=actor_hidden_layer_sizes,
    )

    # Twin Q-networks (n_stack=2), matching the [B, 2] Q-values in critic_loss.
    critic_network = make_q_network(
        n_stack=2,
        hidden_layer_sizes=critic_hidden_layer_sizes,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return ParamSACAgent(
        critic_network=critic_network,
        actor_network=actor_network,
        obs_preprocessor=obs_preprocessor,
        init_alpha=init_alpha,
        discount=discount,
    )
class ParamSACWorkflow(OffPolicyWorkflowTemplate):
    """Workflow for ParamSAC.

    Note: This workflow can only work with PBTParamSACWorkflow, since the
    replay buffer is initialized and managed by PBT externally.

    `setup` leaves `replay_buffer_state` as None, and `step` only samples from
    it (never writes). The freshly collected trajectory is returned inside
    `ParamSACTrainMetric.trajectory` instead. The state's `hp_state` holds
    `actor_loss_weight` and `critic_loss_weight` (both initialized to 1.0),
    which scale the actor and critic losses during updates.
    """

    @classmethod
    def name(cls):
        return "ParamSAC"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create the env, agent, optimizer, replay buffer and evaluator from `config`."""
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            # critic_loss reads next observations from env_extras.ori_obs
            record_ori_obs=True,
        )

        agent = make_mlp_sac_agent(
            action_space=env.action_space,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            init_alpha=config.alpha,
            discount=config.discount,
            normalize_obs=config.normalize_obs,
        )

        # TODO: use different lr for critic and actor
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(
            env,
            agent,
            optimizer,
            evaluator,
            replay_buffer,
            config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent and separate optimizer states for actor and critic."""
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        opt_state = PyTreeDict(
            dict(
                actor=self.optimizer.init(agent_state.params.actor_params),
                critic=self.optimizer.init(agent_state.params.critic_params),
            )
        )

        return agent_state, opt_state

    def setup(self, key: chex.PRNGKey) -> State:
        """Build the initial workflow state.

        The replay buffer state is left as None; it is initialized externally.
        """
        key, agent_key, env_key = jax.random.split(key, 3)

        agent_state, opt_state = self._setup_agent_and_optimizer(agent_key)
        workflow_metrics = self._setup_workflow_metrics()
        env_state = self.env.reset(env_key)

        return State(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
            replay_buffer_state=None,  # init externally
            hp_state=PyTreeDict(
                actor_loss_weight=jnp.float32(1.0),
                critic_loss_weight=jnp.float32(1.0),
            ),
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one rollout followed by `num_updates_per_iter` update rounds.

        Each update round performs `actor_update_interval - 1` critic-only
        updates, then one critic update and one actor update on a shared
        sample batch, then a soft update of the target critic. The replay
        buffer state is only read here; the new trajectory is returned in the
        train metrics so it can be stored externally.

        Returns:
            (train_metrics, new_state). `new_state.replay_buffer_state` is
            left unchanged.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Keep the raw done flags for episode counting before cleaning/flattening.
        trajectory_dones = trajectory.dones
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        # Here replay_buffer_state is read-only,
        # we save the data externally instead
        replay_buffer_state = state.replay_buffer_state

        def critic_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.critic_loss(agent_state, sample_batch, key)

            # The PBT-controlled weight scales the optimized loss only;
            # the unweighted values are kept in loss_dict for logging.
            loss = loss_dict.critic_loss * state.hp_state.critic_loss_weight
            return loss, loss_dict

        def actor_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.actor_loss(agent_state, sample_batch, key)

            loss = loss_dict.actor_loss * state.hp_state.actor_loss_weight
            return loss, loss_dict

        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, critic_params: agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, actor_params: agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry

            critic_opt_state = opt_state.critic
            actor_opt_state = opt_state.actor

            key, critic_key, actor_key, rb_key = jax.random.split(key, num=4)

            # Extra critic-only updates so the critic is updated
            # `actor_update_interval` times per actor update.
            if self.config.actor_update_interval - 1 > 0:

                def _sample_and_update_critic_fn(carry, unused_t):
                    key, agent_state, critic_opt_state = carry

                    key, rb_key, critic_key = jax.random.split(key, num=3)
                    # it's safe to use read-only replay_buffer_state here.
                    sample_batch = self.replay_buffer.sample(
                        replay_buffer_state, rb_key
                    )

                    (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                        critic_update_fn(
                            critic_opt_state, agent_state, sample_batch, critic_key
                        )
                    )

                    return (key, agent_state, critic_opt_state), None

                key, critic_multiple_update_key = jax.random.split(key)

                (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                    _sample_and_update_critic_fn,
                    (critic_multiple_update_key, agent_state, critic_opt_state),
                    (),
                    length=self.config.actor_update_interval - 1,
                )

            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(
                    critic_opt_state, agent_state, sample_batch, critic_key
                )
            )

            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(actor_opt_state, agent_state, sample_batch, actor_key)
            )

            opt_state = opt_state.replace(
                actor=actor_opt_state, critic=critic_opt_state
            )

            res = (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            )

            target_critic_params = soft_target_update(
                agent_state.params.target_critic_params,
                agent_state.params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    target_critic_params=target_critic_params
                )
            )

            return (key, agent_state, opt_state), res

        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )

        train_metrics = ParamSACTrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
            trajectory=trajectory,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # replay_buffer_state is intentionally not updated here: the buffer
        # is managed externally (see the class docstring).
        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )