1import logging
2import math
3from typing import Any
4
5import chex
6import flax.linen as nn
7import jax
8import jax.numpy as jnp
9import jax.tree_util as jtu
10import optax
11from omegaconf import DictConfig
12
13from evorl.distributed import agent_gradient_update, psum
14from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
15from evorl.envs import AutoresetMode, create_env, Space, Box, Discrete
16from evorl.evaluators import Evaluator
17from evorl.metrics import TrainMetric, MetricBase
18from evorl.networks import make_policy_network, make_v_network
19from evorl.rollout import rollout
20from evorl.sample_batch import SampleBatch
21from evorl.types import (
22 MISSING_REWARD,
23 Action,
24 LossDict,
25 Params,
26 PolicyExtraInfo,
27 PyTreeData,
28 PyTreeDict,
29 State,
30 pytree_field,
31)
32from evorl.utils import running_statistics
33from evorl.utils.jax_utils import tree_get, tree_stop_gradient
34from evorl.utils.rl_toolkits import (
35 average_episode_discount_return,
36 compute_gae,
37 flatten_rollout_trajectory,
38)
39from evorl.workflows import OnPolicyWorkflow
40from evorl.recorders import add_prefix
41
42
43from evorl.agent import Agent, AgentState
44
# Module-level logger; A2CWorkflow._rescale_config uses it to warn when env
# counts do not divide evenly across devices.
logger = logging.getLogger(__name__)
46
47
class A2CNetworkParams(PyTreeData):
    """Network parameters of an A2C agent.

    Attributes:
        policy_params: Parameters of the policy (actor) network.
        value_params: Parameters of the value (critic) network.
    """

    policy_params: Params
    value_params: Params
53
54
class A2CAgent(Agent):
    """Advantage Actor-Critic agent with separate policy and value networks.

    Attributes:
        continuous_action: If True, the policy output is split in half along the
            last axis and fed to ``get_tanh_norm_dist``; otherwise it is used as
            the logits of ``get_categorical_dist``.
        policy_network: Flax module mapping observations to raw policy outputs.
        value_network: Flax module mapping observations to state values.
        obs_preprocessor: Optional callable ``(obs, obs_preprocessor_state)``
            applied to observations before both networks; ``None`` disables it.
    """

    continuous_action: bool
    policy_network: nn.Module  # nn.Module is ok
    value_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    @property
    def normalize_obs(self):
        """Whether observations are passed through ``obs_preprocessor``."""
        return self.obs_preprocessor is not None

    def _preprocess_obs(self, agent_state: AgentState, obs):
        """Apply ``obs_preprocessor`` to ``obs`` if one is configured."""
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)
        return obs

    def _action_dist(self, agent_state: AgentState, obs):
        """Build the policy's action distribution for already-preprocessed ``obs``."""
        raw_actions = self.policy_network.apply(agent_state.params.policy_params, obs)
        if self.continuous_action:
            return get_tanh_norm_dist(*jnp.split(raw_actions, 2, axis=-1))
        return get_categorical_dist(raw_actions)

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Initialize network parameters and, optionally, observation statistics.

        Args:
            obs_space: Observation space; one sample with a leading axis of size
                1 added is used as the dummy input for network initialization.
            action_space: Not used here; the policy output size is fixed when
                the networks are built (see ``make_mlp_a2c_agent``).
            key: PRNG key.

        Returns:
            An ``AgentState`` holding ``A2CNetworkParams`` and the running
            statistics state (``None`` when observation normalization is off).
        """
        policy_key, value_key = jax.random.split(key, 2)

        # NOTE(review): ``key`` is reused here after being split above. This is
        # presumably harmless because only the structure/shape of the dummy
        # observation matters for initialization; left as-is so existing seeds
        # reproduce the same initial parameters.
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))

        params_state = A2CNetworkParams(
            policy_params=self.policy_network.init(policy_key, dummy_obs),
            value_params=self.value_network.init(value_key, dummy_obs),
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        return AgentState(
            params=params_state, obs_preprocessor_state=obs_preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample stochastic actions (used during rollouts).

        Returns:
            ``(actions, policy_extras)``; ``policy_extras`` is an empty
            ``PyTreeDict`` because A2C's loss recomputes everything it needs.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions = self._action_dist(agent_state, obs).sample(seed=key)
        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the deterministic (mode) action of the policy, for evaluation.

        ``key`` is accepted for interface compatibility and is not used.
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        actions = self._action_dist(agent_state, obs).mode()
        return actions, PyTreeDict()

    def loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the A2C actor, critic, and entropy terms on a flattened batch.

        Expects ``sample_batch.extras`` to carry ``v_targets`` and ``advantages``
        (filled in by the workflow) and ``env_extras.autoreset``. Transitions
        flagged ``autoreset`` are excluded from every mean.

        Returns:
            ``PyTreeDict`` with ``actor_loss``, ``critic_loss`` and
            ``actor_entropy`` (unweighted; the workflow applies loss weights).
        """
        obs = self._preprocess_obs(agent_state, sample_batch.obs)

        # mask invalid transitions at autoreset
        mask = jnp.logical_not(sample_batch.extras.env_extras.autoreset)

        # ======= critic =======
        vs = self.value_network.apply(agent_state.params.value_params, obs)
        critic_loss = optax.squared_error(vs, sample_batch.extras.v_targets).mean(
            where=mask
        )

        # ====== actor =======
        actions_dist = self._action_dist(agent_state, obs)
        # [T*B]
        actions_logp = actions_dist.log_prob(sample_batch.actions)
        # advantages: [T*B]
        actor_loss = -(sample_batch.extras.advantages * actions_logp).mean(where=mask)

        # entropy: [T*B]; the tanh-normal entropy is estimated with a PRNG key.
        if self.continuous_action:
            entropy = actions_dist.entropy(seed=key)
        else:
            entropy = actions_dist.entropy()
        actor_entropy = entropy.mean(where=mask)

        return PyTreeDict(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            actor_entropy=actor_entropy,
        )

    def compute_values(
        self, agent_state: AgentState, sample_batch: SampleBatch
    ) -> chex.Array:
        """Evaluate the value network on ``sample_batch.obs`` (after preprocessing)."""
        obs = self._preprocess_obs(agent_state, sample_batch.obs)
        return self.value_network.apply(agent_state.params.value_params, obs)
187
188
def make_mlp_a2c_agent(
    action_space: Space,
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    normalize_obs: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> A2CAgent:
    """Build an ``A2CAgent`` with MLP policy and value networks.

    Args:
        action_space: ``Box`` (continuous) or ``Discrete`` action space. For a
            ``Box`` the policy outputs ``2 * action_dim`` values, which the agent
            splits in half to parameterize its tanh-normal distribution; for a
            ``Discrete`` space it outputs one logit per action.
        actor_hidden_layer_sizes: Hidden layer widths of the policy MLP.
        critic_hidden_layer_sizes: Hidden layer widths of the value MLP.
        normalize_obs: If True, observations are normalized with
            ``running_statistics.normalize``.
        policy_obs_key: Forwarded as ``obs_key`` to the policy network builder.
        value_obs_key: Forwarded as ``obs_key`` to the value network builder.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor ``Discrete``.
    """
    if isinstance(action_space, Box):
        action_size = action_space.shape[0] * 2
        continuous_action = True
    elif isinstance(action_space, Discrete):
        action_size = action_space.n
        continuous_action = False
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    policy_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        obs_key=policy_obs_key,
    )

    value_network = make_v_network(
        hidden_layer_sizes=critic_hidden_layer_sizes, obs_key=value_obs_key
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return A2CAgent(
        policy_network=policy_network,
        value_network=value_network,
        obs_preprocessor=obs_preprocessor,
        continuous_action=continuous_action,
    )
227
228
class A2CWorkflow(OnPolicyWorkflow):
    """On-policy workflow that trains an ``A2CAgent``.

    Each ``step`` collects ``rollout_length`` transitions from every environment,
    computes GAE targets, and performs one gradient update on the whole batch.
    """

    @classmethod
    def name(cls):
        """Return the workflow's display name."""
        return "A2C"

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Convert ``num_envs`` / ``num_eval_envs`` into per-device counts, in place.

        Counts that are not divisible by the device count are floored, with a
        warning.
        """
        num_devices = jax.device_count()

        if config.num_envs % num_devices != 0:
            logger.warning(
                f"num_envs({config.num_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_envs to {config.num_envs // num_devices}"
            )
        if config.num_eval_envs % num_devices != 0:
            logger.warning(
                f"num_eval_envs({config.num_eval_envs}) cannot be divided by num_devices({num_devices}), "
                f"rescale num_eval_envs to {config.num_eval_envs // num_devices}"
            )

        config.num_envs = config.num_envs // num_devices
        config.num_eval_envs = config.num_eval_envs // num_devices
        # Note: batch_size = num_envs * rollout_length, no need to rescale again

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create the training env, agent, optimizer and evaluator from ``config``."""
        max_episode_steps = config.env.max_episode_steps

        env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.ENVPOOL,
        )

        agent = make_mlp_a2c_agent(
            action_space=env.action_space,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
            policy_obs_key=config.agent_network.policy_obs_key,
            value_obs_key=config.agent_network.value_obs_key,
        )

        # Adam, optionally preceded by global-norm gradient clipping.
        grad_clip_norm = config.optimizer.grad_clip_norm
        optimizer = optax.adam(config.optimizer.lr)
        if grad_clip_norm is not None and grad_clip_norm > 0:
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optimizer,
            )

        eval_env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, config)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one A2C iteration: rollout, GAE, one gradient update, metrics.

        Returns:
            ``(train_metrics, new_state)``.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # trajectory: [T, #envs, ...]
        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("autoreset", "episode_return", "termination"),
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        train_episode_return = average_episode_discount_return(
            trajectory.extras.env_extras.episode_return,
            trajectory.dones,
            dp_axis_name=self.dp_axis_name,
        )

        # ======== compute GAE =======
        # Append the last next_obs so the final transition can be bootstrapped.
        _obs = jtu.tree_map(
            lambda obs, next_obs: jnp.concatenate([obs, next_obs[-1:]], axis=0),
            trajectory.obs,
            trajectory.next_obs,
        )
        # concat [values, bootstrap_value]
        # Use the *updated* agent_state so the baseline/targets are computed with
        # the same observation statistics the critic is trained with in
        # ``agent.loss`` (previously the stale ``state.agent_state`` was used).
        vs = self.agent.compute_values(agent_state, SampleBatch(obs=_obs))
        v_targets, advantages = compute_gae(
            rewards=trajectory.rewards,
            values=vs,
            dones=trajectory.dones,
            terminations=trajectory.extras.env_extras.termination,
            gae_lambda=self.config.gae_lambda,
            discount=self.config.discount,
        )

        trajectory.extras.v_targets = jax.lax.stop_gradient(v_targets)
        trajectory.extras.advantages = jax.lax.stop_gradient(advantages)
        # [T,B,...] -> [T*B,...]
        trajectory = tree_stop_gradient(flatten_rollout_trajectory(trajectory))
        # ============================

        def loss_fn(agent_state, sample_batch, key):
            # learn all data from trajectory; total loss = sum of weighted terms
            loss_dict = self.agent.loss(agent_state, sample_batch, key)
            loss = jnp.zeros(())
            for term_name, weight in self.config.loss_weights.items():
                loss += weight * loss_dict[term_name]

            return loss, loss_dict

        update_fn = agent_gradient_update(
            loss_fn, self.optimizer, dp_axis_name=self.dp_axis_name, has_aux=True
        )

        (loss, loss_dict), agent_state, opt_state = update_fn(
            state.opt_state, agent_state, trajectory, learn_key
        )

        # ======== update metrics ========

        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory.dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        train_metrics = TrainMetric(
            train_episode_return=train_episode_return,
            loss=loss,
            raw_loss_dict=loss_dict,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )

    def learn(self, state: State) -> State:
        """Run ``step`` repeatedly until ``config.total_timesteps`` is covered.

        Each iteration records workflow and train metrics. It evaluates every
        ``eval_interval`` iterations and on the last one. It then saves a
        checkpoint, which is forced on the last iteration. Iteration counting
        starts from ``state.metrics.iterations``, e.g. for a restored state.
        """
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        # NOTE(review): after ``_rescale_config`` ``num_envs`` is per-device, while
        # ``step`` psums timesteps over devices; with more than one device this
        # may run more iterations than ``total_timesteps`` implies — confirm.
        num_iters = math.ceil(self.config.total_timesteps / one_step_timesteps)

        start_iteration = state.metrics.iterations.tolist()

        for i in range(start_iteration, num_iters):
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            iters = i + 1

            self.recorder.write(workflow_metrics.to_local_dict(), iters)
            train_metric_data = train_metrics.to_local_dict()
            # No episode finished in this rollout -> log a null return.
            if train_metrics.train_episode_return == MISSING_REWARD:
                train_metric_data["train_episode_return"] = None
            self.recorder.write(train_metric_data, iters)

            if iters % self.config.eval_interval == 0 or iters == num_iters:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            self.checkpoint_manager.save(
                iters,
                state,
                force=iters == num_iters,
            )

        return state