Source code for evorl.algorithms.a2c

  1import logging
  2import math
  3from typing import Any
  4
  5import chex
  6import flax.linen as nn
  7import jax
  8import jax.numpy as jnp
  9import jax.tree_util as jtu
 10import optax
 11from omegaconf import DictConfig
 12
 13from evorl.distributed import agent_gradient_update, psum
 14from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
 15from evorl.envs import AutoresetMode, create_env, Space, Box, Discrete
 16from evorl.evaluators import Evaluator
 17from evorl.metrics import TrainMetric, MetricBase
 18from evorl.networks import make_policy_network, make_v_network
 19from evorl.rollout import rollout
 20from evorl.sample_batch import SampleBatch
 21from evorl.types import (
 22    MISSING_REWARD,
 23    Action,
 24    LossDict,
 25    Params,
 26    PolicyExtraInfo,
 27    PyTreeData,
 28    PyTreeDict,
 29    State,
 30    pytree_field,
 31)
 32from evorl.utils import running_statistics
 33from evorl.utils.jax_utils import tree_get, tree_stop_gradient
 34from evorl.utils.rl_toolkits import (
 35    average_episode_discount_return,
 36    compute_gae,
 37    flatten_rollout_trajectory,
 38)
 39from evorl.workflows import OnPolicyWorkflow
 40from evorl.recorders import add_prefix
 41
 42
 43from evorl.agent import Agent, AgentState
 44
 45logger = logging.getLogger(__name__)
 46
 47
class A2CNetworkParams(PyTreeData):
    """Bundle of the learnable network parameters used by the A2C agent.

    The two entries are kept separate so the policy (actor) and value
    (critic) networks can be applied independently.
    """

    # Parameters passed to ``policy_network.apply``.
    policy_params: Params
    # Parameters passed to ``value_network.apply``.
    value_params: Params
53 54
class A2CAgent(Agent):
    """Advantage actor-critic agent with separate policy and value networks.

    Attributes:
        continuous_action: If True, the policy output is split in two halves
            along the last axis and fed to ``get_tanh_norm_dist``; otherwise
            the output is fed to ``get_categorical_dist``.
        policy_network: Module producing the raw policy outputs.
        value_network: Module producing state-value estimates.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``
            applied to observations before they reach the networks.
    """

    continuous_action: bool
    policy_network: nn.Module  # nn.Module is ok
    value_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor has been configured."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs):
        """Apply the observation preprocessor when one is configured."""
        if not self.normalize_obs:
            return obs
        return self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

    def _action_dist(self, agent_state: AgentState, obs):
        """Run the policy network and wrap its output in an action distribution."""
        policy_out = self.policy_network.apply(agent_state.params.policy_params, obs)
        if not self.continuous_action:
            return get_categorical_dist(policy_out)
        # The policy head emits two equally sized chunks on the last axis
        # (presumably location and scale -- confirm against get_tanh_norm_dist).
        first_half, second_half = jnp.split(policy_out, 2, axis=-1)
        return get_tanh_norm_dist(first_half, second_half)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        Args:
            obs_space: Observation space; a sample from it (with a leading
                batch axis of size 1) is used to initialise both networks.
            action_space: Unused here.
            key: PRNG key, split into one key per network.

        Returns:
            An ``AgentState`` holding the network parameters and, when
            observation normalisation is enabled, fresh running statistics.
        """
        policy_key, value_key = jax.random.split(key, 2)

        # NOTE(review): ``key`` is reused here after being split; presumably
        # harmless because the sample only serves as a dummy input -- confirm.
        batched_sample = jtu.tree_map(lambda leaf: leaf[None, ...], obs_space.sample(key))

        network_params = A2CNetworkParams(
            policy_params=self.policy_network.init(policy_key, batched_sample),
            value_params=self.value_network.init(value_key, batched_sample),
        )

        # Note: statistics are broadcasted to [T*B]
        preprocessor_state = (
            running_statistics.init_state(tree_get(batched_sample, 0))
            if self.normalize_obs
            else None
        )

        return AgentState(
            params=network_params, obs_preprocessor_state=preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions from the current policy.

        Args:
            agent_state: Current agent state.
            sample_batch: Batch whose ``obs`` field is read.
            key: PRNG key used as the sampling seed.

        Returns:
            The sampled actions and an empty ``PyTreeDict`` (no policy extras
            such as raw outputs or log-probabilities are recorded).
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        dist = self._action_dist(agent_state, obs)
        sampled_actions = dist.sample(seed=key)
        return sampled_actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Pick the mode of the action distribution.

        Args:
            agent_state: Current agent state.
            sample_batch: Batch whose ``obs`` field is read.
            key: Unused; kept for signature parity with ``compute_actions``.

        Returns:
            The distribution's mode and an empty ``PyTreeDict``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        greedy_actions = self._action_dist(agent_state, obs).mode()
        return greedy_actions, PyTreeDict()

    def loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the actor loss, critic loss and policy entropy.

        Reads ``obs``, ``actions``, ``extras.v_targets``, ``extras.advantages``
        and ``extras.env_extras.autoreset`` from ``sample_batch``. Entries
        flagged as autoreset are excluded from every mean.

        Args:
            agent_state: Current agent state.
            sample_batch: Flattened trajectory data.
            key: PRNG key passed to the entropy call for continuous actions.

        Returns:
            A ``PyTreeDict`` with ``actor_loss``, ``critic_loss`` and
            ``actor_entropy``.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        extras = sample_batch.extras

        # mask invalid transitions at autoreset
        valid = jnp.logical_not(extras.env_extras.autoreset)

        # ======= critic =======
        values = self.value_network.apply(agent_state.params.value_params, obs)
        critic_loss = optax.squared_error(values, extras.v_targets).mean(where=valid)

        # ====== actor =======
        dist = self._action_dist(agent_state, obs)
        # [T*B]
        log_probs = dist.log_prob(sample_batch.actions)
        advantages = extras.advantages
        actor_loss = -(advantages * log_probs).mean(where=valid)

        # The continuous-action distribution's entropy call takes a seed.
        if self.continuous_action:
            entropy = dist.entropy(seed=key)
        else:
            entropy = dist.entropy()
        actor_entropy = entropy.mean(where=valid)

        return PyTreeDict(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            actor_entropy=actor_entropy,
        )

    def compute_values(
        self, agent_state: AgentState, sample_batch: SampleBatch
    ) -> chex.Array:
        """Evaluate the value network on ``sample_batch.obs``.

        Args:
            agent_state: Current agent state.
            sample_batch: Batch whose ``obs`` field is read.

        Returns:
            The value network output for the (optionally preprocessed)
            observations.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        return self.value_network.apply(agent_state.params.value_params, obs)
187 188
def make_mlp_a2c_agent(
    action_space: Space,
    actor_hidden_layer_sizes: tuple[int] = (256, 256),
    critic_hidden_layer_sizes: tuple[int] = (256, 256),
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> A2CAgent:
    """Build an ``A2CAgent`` for the given action space.

    Args:
        action_space: A ``Box`` (continuous) or ``Discrete`` action space.
        actor_hidden_layer_sizes: Hidden layer sizes forwarded to
            ``make_policy_network``.
        critic_hidden_layer_sizes: Hidden layer sizes forwarded to
            ``make_v_network``.
        normalize_obs: If True, ``running_statistics.normalize`` is installed
            as the observation preprocessor.
        policy_obs_key: ``obs_key`` forwarded to the policy network factory.
        value_obs_key: ``obs_key`` forwarded to the value network factory.

    Returns:
        The configured ``A2CAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor
            ``Discrete``.
    """
    if not isinstance(action_space, (Box, Discrete)):
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    # Box is tested first, matching the precedence of the two supported kinds.
    is_continuous = isinstance(action_space, Box)
    # Continuous policies emit two values per action dimension; discrete
    # policies emit one value per action.
    policy_output_size = (
        action_space.shape[0] * 2 if is_continuous else action_space.n
    )

    actor = make_policy_network(
        action_size=policy_output_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        obs_key=policy_obs_key,
    )
    critic = make_v_network(
        hidden_layer_sizes=critic_hidden_layer_sizes, obs_key=value_obs_key
    )

    return A2CAgent(
        policy_network=actor,
        value_network=critic,
        obs_preprocessor=running_statistics.normalize if normalize_obs else None,
        continuous_action=is_continuous,
    )
227 228
class A2CWorkflow(OnPolicyWorkflow):
    """On-policy training workflow for the A2C agent."""

    @classmethod
    def name(cls):
        """Return the workflow's display name."""
        return "A2C"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Convert global env counts in ``config`` to per-device counts, in place.

        A warning is logged for each count that is not a multiple of
        ``jax.device_count()``; the count is then floor-divided regardless.
        """
        num_devices = jax.device_count()

        if config.num_envs % num_devices:
            logger.warning(
                f"num_envs({config.num_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_envs to {config.num_envs // num_devices}"
            )
        if config.num_eval_envs % num_devices:
            logger.warning(
                f"num_eval_envs({config.num_eval_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_eval_envs to {config.num_eval_envs // num_devices}"
            )

        config.num_envs //= num_devices
        config.num_eval_envs //= num_devices
        # Note: batch_size = num_envs * rollout_length, no need to rescale again

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct the workflow (envs, agent, optimizer, evaluator) from config.

        The training environment uses ``AutoresetMode.ENVPOOL``; the evaluation
        environment uses ``AutoresetMode.DISABLED``. Gradient clipping by
        global norm is prepended to Adam only when
        ``config.optimizer.grad_clip_norm`` is set and positive.
        """
        max_episode_steps = config.env.max_episode_steps
        net_cfg = config.agent_network
        opt_cfg = config.optimizer

        env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.ENVPOOL,
        )

        agent = make_mlp_a2c_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=net_cfg.actor_hidden_layer_sizes,
            critic_hidden_layer_sizes=net_cfg.critic_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
            policy_obs_key=net_cfg.policy_obs_key,
            value_obs_key=net_cfg.value_obs_key,
        )

        optimizer = optax.adam(opt_cfg.lr)
        clip_norm = opt_cfg.grad_clip_norm
        if clip_norm is not None and clip_norm > 0:
            optimizer = optax.chain(optax.clip_by_global_norm(clip_norm), optimizer)

        eval_env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, config)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: rollout, GAE, one gradient update, metrics.

        Args:
            state: Current workflow state (PRNG key, env/agent/optimizer state
                and running metrics).

        Returns:
            A tuple of the ``TrainMetric`` for this iteration and the updated
            workflow state.
        """
        next_key, rollout_key, learn_key = jax.random.split(state.key, num=3)
        cfg = self.config
        axis_name = self.dp_axis_name

        # trajectory: [T, #envs, ...]
        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=cfg.rollout_length,
            env_extra_fields=("autoreset", "episode_return", "termination"),
        )

        # Fold the freshly collected observations into the running statistics.
        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            updated_stats = running_statistics.update(
                agent_state.obs_preprocessor_state,
                trajectory.obs,
                dp_axis_name=axis_name,
            )
            agent_state = agent_state.replace(obs_preprocessor_state=updated_stats)

        train_episode_return = average_episode_discount_return(
            trajectory.extras.env_extras.episode_return,
            trajectory.dones,
            dp_axis_name=axis_name,
        )

        # ======== compute GAE =======
        def _append_final_next_obs(obs, next_obs):
            # Add the last next-observation so the bootstrap value is included.
            return jnp.concatenate([obs, next_obs[-1:]], axis=0)

        obs_with_bootstrap = jtu.tree_map(
            _append_final_next_obs, trajectory.obs, trajectory.next_obs
        )
        # concat [values, bootstrap_value]
        # NOTE(review): values are computed with ``state.agent_state`` (the
        # observation statistics from before this step's update), while the
        # loss below uses the updated ``agent_state`` -- confirm intended.
        values = self.agent.compute_values(
            state.agent_state, SampleBatch(obs=obs_with_bootstrap)
        )
        v_targets, advantages = compute_gae(
            rewards=trajectory.rewards,
            values=values,
            dones=trajectory.dones,
            terminations=trajectory.extras.env_extras.termination,
            gae_lambda=cfg.gae_lambda,
            discount=cfg.discount,
        )

        trajectory.extras.v_targets = jax.lax.stop_gradient(v_targets)
        trajectory.extras.advantages = jax.lax.stop_gradient(advantages)
        # [T,B,...] -> [T*B,...]
        trajectory = tree_stop_gradient(flatten_rollout_trajectory(trajectory))
        # ============================

        def loss_fn(agent_state, sample_batch, key):
            # learn all data from trajectory
            loss_dict = self.agent.loss(agent_state, sample_batch, key)
            weights = self.config.loss_weights
            # Weighted sum over the loss terms listed in the config.
            total = sum(
                (weights[term] * loss_dict[term] for term in weights.keys()),
                jnp.zeros(()),
            )
            return total, loss_dict

        update_fn = agent_gradient_update(
            loss_fn, self.optimizer, dp_axis_name=axis_name, has_aux=True
        )

        (loss, loss_dict), agent_state, opt_state = update_fn(
            state.opt_state, agent_state, trajectory, learn_key
        )

        # ======== update metrics ========
        step_timesteps = psum(
            jnp.uint32(cfg.rollout_length * cfg.num_envs),
            axis_name=axis_name,
        )
        step_episodes = psum(
            trajectory.dones.sum().astype(jnp.uint32), axis_name=axis_name
        )

        previous = state.metrics
        workflow_metrics = previous.replace(
            sampled_timesteps=previous.sampled_timesteps + step_timesteps,
            sampled_episodes=previous.sampled_episodes + step_episodes,
            iterations=previous.iterations + 1,
        ).all_reduce(dp_axis_name=axis_name)

        train_metrics = TrainMetric(
            train_episode_return=train_episode_return,
            loss=loss,
            raw_loss_dict=loss_dict,
        ).all_reduce(dp_axis_name=axis_name)

        new_state = state.replace(
            key=next_key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )
        return train_metrics, new_state

    def learn(self, state: State) -> State:
        """Run the training loop until ``total_timesteps`` worth of iterations.

        Resumes from ``state.metrics.iterations``. Each iteration records
        workflow and training metrics, evaluates every
        ``config.eval_interval`` iterations and on the final one, and hands
        the state to the checkpoint manager (forced on the final iteration).

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The workflow state after the last iteration.
        """
        timesteps_per_iter = self.config.rollout_length * self.config.num_envs
        num_iters = math.ceil(self.config.total_timesteps / timesteps_per_iter)

        completed = state.metrics.iterations.tolist()

        # ``iters`` is the 1-based index of the iteration being executed.
        for iters in range(completed + 1, num_iters + 1):
            train_metrics, state = self.step(state)
            is_last = iters == num_iters

            self.recorder.write(state.metrics.to_local_dict(), iters)

            train_record = train_metrics.to_local_dict()
            # Replace the sentinel return with None before recording.
            if train_metrics.train_episode_return == MISSING_REWARD:
                train_record["train_episode_return"] = None
            self.recorder.write(train_record, iters)

            if iters % self.config.eval_interval == 0 or is_last:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            self.checkpoint_manager.save(
                iters,
                state,
                force=is_last,
            )

        return state