1import logging
2from functools import partial
3from omegaconf import DictConfig
4
5import chex
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9import optax
10
11from evorl.agent import AgentState
12from evorl.envs import Space, create_env, AutoresetMode, Box, Discrete
13from evorl.evaluators import Evaluator
14from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
15from evorl.distributed import agent_gradient_update, psum
16from evorl.metrics import TrainMetric, MetricBase
17from evorl.networks import make_policy_network, make_v_network
18from evorl.rollout import rollout
19from evorl.sample_batch import SampleBatch
20from evorl.types import PyTreeDict, State, LossDict
21from evorl.utils import running_statistics
22from evorl.utils.jax_utils import tree_stop_gradient, scan_and_mean
23from evorl.utils.rl_toolkits import (
24 average_episode_discount_return,
25 compute_gae,
26 flatten_rollout_trajectory,
27 approximate_kl,
28)
29
30from evorl.algorithms.ppo import PPOWorkflow, PPOAgent
31
32logger = logging.getLogger(__name__)
33
34
[docs]
class ParamPPOAgent(PPOAgent):
    """PPO agent whose clipping range is carried inside the agent state.

    ``init`` stores ``clip_epsilon`` in ``AgentState.extra_state`` and ``loss``
    reads the value back from there rather than from ``self.clip_epsilon``.
    """

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the parent agent state and attach ``clip_epsilon`` to it.

        Args:
            obs_space: Observation space, forwarded to the parent ``init``.
            action_space: Action space, forwarded to the parent ``init``.
            key: PRNG key, forwarded to the parent ``init``.

        Returns:
            The parent's agent state with ``extra_state`` replaced by a
            ``PyTreeDict`` holding ``clip_epsilon`` as a float32 scalar.
        """
        base_state = super().init(obs_space, action_space, key)
        extra_state = PyTreeDict(clip_epsilon=jnp.float32(self.clip_epsilon))
        return base_state.replace(extra_state=extra_state)

    def loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the individual PPO loss terms for one batch.

        Args:
            agent_state: Agent state providing the network parameters, the
                observation-preprocessor state and ``extra_state.clip_epsilon``.
            sample_batch: Batch providing ``obs``, ``actions`` and, under
                ``extras``, ``v_targets``, ``advantages``,
                ``policy_extras.logp`` and ``env_extras.autoreset``.
            key: PRNG key, passed as ``seed`` to the entropy computation of the
                continuous-action distribution.

        Returns:
            A ``PyTreeDict`` with ``actor_loss``, ``critic_loss``,
            ``actor_entropy`` and ``approx_kl``. The terms are returned
            unweighted; nothing is summed here.
        """
        extras = sample_batch.extras

        observations = sample_batch.obs
        if self.normalize_obs:
            observations = self.obs_preprocessor(
                observations, agent_state.obs_preprocessor_state
            )

        # Transitions flagged as autoreset are left out of every masked mean.
        valid = jnp.logical_not(extras.env_extras.autoreset)

        # ---------------- critic ----------------
        predicted_values = self.value_network.apply(
            agent_state.params.value_params, observations
        )
        critic_loss = optax.squared_error(predicted_values, extras.v_targets).mean(
            where=valid
        )

        # ---------------- actor -----------------
        # [T*B, A]
        policy_out = self.policy_network.apply(
            agent_state.params.policy_params, observations
        )
        if self.continuous_action:
            # The network output is cut into two equal halves along the last
            # axis; both halves parameterize the tanh-normal distribution.
            dist_params = jnp.split(policy_out, 2, axis=-1)
            actions_dist = get_tanh_norm_dist(*dist_params)
        else:
            actions_dist = get_categorical_dist(policy_out)

        # [T*B]
        current_logp = actions_dist.log_prob(sample_batch.actions)
        rollout_logp = extras.policy_extras.logp

        advantages = extras.advantages
        if self.normalize_gae:
            # Note: mean/std are taken over the whole batch (no mask applied).
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        log_ratio = current_logp - rollout_logp
        ratio = jnp.exp(log_ratio)

        # Clipped surrogate objective; the clip range comes from the state.
        epsilon = agent_state.extra_state.clip_epsilon
        unclipped_objective = ratio * advantages
        clipped_objective = jnp.clip(ratio, 1 - epsilon, 1 + epsilon) * advantages
        surrogate = jnp.minimum(unclipped_objective, clipped_objective)
        actor_loss = -surrogate.mean(where=valid)

        # entropy: [T*B]; only the continuous distribution is given a seed.
        entropy = (
            actions_dist.entropy(seed=key)
            if self.continuous_action
            else actions_dist.entropy()
        )
        actor_entropy = entropy.mean(where=valid)

        approx_kl = approximate_kl(log_ratio)

        return PyTreeDict(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            actor_entropy=actor_entropy,
            approx_kl=approx_kl,
        )
111
112
[docs]
def make_mlp_ppo_agent(
    action_space: Space,
    clip_epsilon: float = 0.2,
    actor_hidden_layer_sizes: tuple[int] = (256, 256),
    critic_hidden_layer_sizes: tuple[int] = (256, 256),
    normalize_obs: bool = False,
):
    """Build a ``ParamPPOAgent`` for the given action space.

    Args:
        action_space: A ``Box`` (continuous actions) or ``Discrete`` space.
        clip_epsilon: Clipping range handed to the agent.
        actor_hidden_layer_sizes: Hidden layer sizes of the policy network.
        critic_hidden_layer_sizes: Hidden layer sizes of the value network.
        normalize_obs: When true, ``running_statistics.normalize`` is used as
            the observation preprocessor; otherwise no preprocessor is set.

    Returns:
        The configured ``ParamPPOAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor
            ``Discrete``.
    """
    if isinstance(action_space, Box):
        # The policy head is twice the action dimension for continuous actions.
        continuous_action, action_size = True, action_space.shape[0] * 2
    elif isinstance(action_space, Discrete):
        continuous_action, action_size = False, action_space.n
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    return ParamPPOAgent(
        continuous_action=continuous_action,
        policy_network=make_policy_network(
            action_size=action_size,
            hidden_layer_sizes=actor_hidden_layer_sizes,
        ),
        value_network=make_v_network(hidden_layer_sizes=critic_hidden_layer_sizes),
        obs_preprocessor=running_statistics.normalize if normalize_obs else None,
        clip_epsilon=clip_epsilon,
    )
148
149
[docs]
class ParamPPOWorkflow(PPOWorkflow):
    """PPO workflow whose hyperparameters are stored in ``state.hp_state``.

    ``setup`` writes the GAE lambda, the discount and the three loss weights
    into ``hp_state``; ``step`` reads them from there on every iteration
    instead of reading the config directly.
    """

    @classmethod
    def name(cls):
        """Return the workflow name, ``"ParamPPO"``."""
        return "ParamPPO"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create environments, agent, optimizer and evaluator from ``config``.

        Args:
            config: Workflow configuration. Reads ``env``, ``num_envs``,
                ``num_eval_envs``, ``clip_epsilon``, ``agent_network``,
                ``normalize_obs``, ``optimizer``, ``rollout_length`` and
                ``minibatch_size``.

        Returns:
            A workflow instance built as
            ``cls(env, agent, optimizer, evaluator, config)``.
        """
        max_episode_steps = config.env.max_episode_steps

        # Training environments use the ENVPOOL autoreset mode.
        env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.ENVPOOL,
        )

        agent = make_mlp_ppo_agent(
            action_space=env.action_space,
            clip_epsilon=config.clip_epsilon,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is only chained in for a positive clip norm.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        # Evaluation environments run with autoreset disabled.
        eval_env = create_env(
            config.env,
            episode_length=max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        one_step_rollout_steps = config.num_envs * config.rollout_length
        if one_step_rollout_steps % config.minibatch_size != 0:
            # ``step`` truncates the shuffled rollout to a whole number of
            # minibatches, so the leftover samples are not used in an epoch.
            logger.warning(
                "minibatch_size (%s) does not evenly divide "
                "num_envs*rollout_length (%s); leftover samples are dropped "
                "in each epoch",
                config.minibatch_size,
                one_step_rollout_steps,
            )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=max_episode_steps,
        )

        return cls(env, agent, optimizer, evaluator, config)

    def setup(self, key: chex.PRNGKey) -> State:
        """Initialize the workflow state and populate ``hp_state``.

        ``gae_lambda`` and ``discount`` are stored in a transformed form
        ``g = -log(1 - x)``; ``step`` inverts it with ``x = 1 - exp(-g)``.
        The loss weights are stored as float32 scalars.

        Args:
            key: PRNG key forwarded to the parent ``setup``.

        Returns:
            The parent's state with ``hp_state`` replaced.
        """
        state = super().setup(key)

        return state.replace(
            hp_state=PyTreeDict(
                gae_lambda_g=-jnp.log(1 - jnp.float32(self.config.gae_lambda)),
                discount_g=-jnp.log(
                    1 - jnp.float32(self.config.discount)
                ),  # discount = 1 - exp(-g)
                actor_loss_weight=jnp.float32(self.config.loss_weights.actor_loss),
                critic_loss_weight=jnp.float32(self.config.loss_weights.critic_loss),
                entropy_loss_weight=jnp.float32(self.config.loss_weights.actor_entropy),
            )
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration.

        Collects a rollout, updates the observation statistics (when
        present), computes GAE targets, performs ``reuse_rollout_epochs``
        epochs of shuffled minibatch updates and refreshes the metrics.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, new_state)``.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # trajectory: [T, #envs, ...]
        trajectory, env_state = rollout(
            self.env.step,
            self.agent.compute_actions,
            state.env_state,
            state.agent_state,
            rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("autoreset", "episode_return", "termination"),
        )

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        train_episode_return = average_episode_discount_return(
            trajectory.extras.env_extras.episode_return,
            trajectory.dones,
            dp_axis_name=self.dp_axis_name,
        )

        # ======== compute GAE =======
        # Append the final next_obs so values include a bootstrap entry.
        _obs = jtu.tree_map(
            lambda obs, next_obs: jnp.concatenate([obs, next_obs[-1:]], axis=0),
            trajectory.obs,
            trajectory.next_obs,
        )
        # concat [values, bootstrap_value]
        # NOTE(review): values are computed with ``state.agent_state`` (the
        # pre-update observation statistics) while the loss below uses the
        # updated ``agent_state`` -- confirm this is intended.
        vs = self.agent.compute_values(state.agent_state, SampleBatch(obs=_obs))

        # Invert the transform applied in ``setup``: x = 1 - exp(-g).
        gae_lambda = 1 - jnp.exp(-state.hp_state.gae_lambda_g)
        discount = 1 - jnp.exp(-state.hp_state.discount_g)

        v_targets, advantages = compute_gae(
            rewards=trajectory.rewards,  # peb_rewards
            values=vs,
            dones=trajectory.dones,
            terminations=trajectory.extras.env_extras.termination,
            gae_lambda=gae_lambda,
            discount=discount,
        )
        trajectory.extras.v_targets = jax.lax.stop_gradient(v_targets)
        trajectory.extras.advantages = jax.lax.stop_gradient(advantages)
        # [T,B,...] -> [T*B,...]
        trajectory = tree_stop_gradient(flatten_rollout_trajectory(trajectory))
        # ============================

        def loss_fn(agent_state, sample_batch, key):
            # Weighted sum of the agent's loss terms, with the weights taken
            # from hp_state; the unweighted terms are returned as aux output.
            loss_dict = self.agent.loss(agent_state, sample_batch, key)
            loss_weights = dict(
                actor_loss=state.hp_state.actor_loss_weight,
                critic_loss=state.hp_state.critic_loss_weight,
                actor_entropy=state.hp_state.entropy_loss_weight,
            )
            loss = jnp.zeros(())
            for loss_key, weight in loss_weights.items():
                loss += weight * loss_dict[loss_key]

            return loss, loss_dict

        update_fn = agent_gradient_update(
            loss_fn, self.optimizer, dp_axis_name=self.dp_axis_name, has_aux=True
        )

        num_minibatches = (
            self.config.rollout_length
            * self.config.num_envs
            // self.config.minibatch_size
        )

        def _get_shuffled_minibatch(perm_key, x):
            # Shuffle along the leading axis, drop the samples that do not
            # fill a whole minibatch, then reshape to
            # [num_minibatches, minibatch_size, ...].
            x = x[jax.random.permutation(perm_key, x.shape[0])][
                : num_minibatches * self.config.minibatch_size
            ]
            return x.reshape(num_minibatches, self.config.minibatch_size, *x.shape[1:])

        def minibatch_step(carry, trajectory):
            # One gradient update on a single minibatch.
            opt_state, agent_state, key = carry
            key, learn_key = jax.random.split(key)

            (loss, loss_dict), agent_state, opt_state = update_fn(
                opt_state, agent_state, trajectory, learn_key
            )

            return (opt_state, agent_state, key), (loss, loss_dict)

        def epoch_step(carry, _):
            # One pass over the rollout. Every leaf is shuffled with the same
            # perm_key, so all fields share one permutation.
            opt_state, agent_state, key = carry
            perm_key, learn_key = jax.random.split(key, num=2)

            (opt_state, agent_state, key), (loss, loss_dict) = scan_and_mean(
                minibatch_step,
                (opt_state, agent_state, learn_key),
                jtu.tree_map(partial(_get_shuffled_minibatch, perm_key), trajectory),
                length=num_minibatches,
            )

            return (opt_state, agent_state, key), (loss, loss_dict)

        # loss_list: [reuse_rollout_epochs, num_minibatches]
        # (presumably reduced to means by scan_and_mean -- helper not shown here)
        (opt_state, agent_state, _), (loss, loss_dict) = scan_and_mean(
            epoch_step,
            (state.opt_state, agent_state, learn_key),
            None,
            length=self.config.reuse_rollout_epochs,
        )

        # ======== update metrics ========

        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory.dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        train_metrics = TrainMetric(
            train_episode_return=train_episode_return,
            loss=loss,
            raw_loss_dict=loss_dict,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )