1import logging
2from omegaconf import DictConfig
3
4import chex
5import jax
6import jax.numpy as jnp
7import optax
8
9from evorl.agent import AgentState
10from evorl.envs import Space, Box, create_env, AutoresetMode
11from evorl.evaluators import Evaluator
12from evorl.distributed import agent_gradient_update, psum
13from evorl.distribution import get_tanh_norm_dist
14from evorl.metrics import MetricBase
15from evorl.rollout import rollout
16from evorl.replay_buffers import ReplayBuffer
17from evorl.types import PyTreeDict, State, LossDict
18from evorl.sample_batch import SampleBatch
19from evorl.networks import make_policy_network, make_q_network
20from evorl.utils import running_statistics
21from evorl.utils.jax_utils import scan_and_mean, tree_stop_gradient
22from evorl.utils.rl_toolkits import flatten_rollout_trajectory, soft_target_update
23
24from evorl.algorithms.offpolicy_utils import clean_trajectory, OffPolicyWorkflowTemplate
25from evorl.algorithms.sac import SACTrainMetric, SACAgent
26
27
28logger = logging.getLogger(__name__)
29
30
[docs]
class ParamSACTrainMetric(SACTrainMetric):
    """SAC training metrics with an extra slot for sampled trajectory data."""

    # Defaults to None; filled in by whoever builds the metric.
    trajectory: SampleBatch = None
33
34
[docs]
class ParamSACAgent(SACAgent):
    """SAC agent with parameterized hyperparameters.

    The discount factor is kept in ``agent_state.extra_state.discount_g``
    using the parameterization ``discount = 1 - exp(-g)``; ``critic_loss``
    reads the discount from there.
    """

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the agent state and store ``discount_g`` in its extra state.

        Args:
            obs_space: Observation space, forwarded to ``SACAgent.init``.
            action_space: Action space, forwarded to ``SACAgent.init``.
            key: PRNG key, forwarded to ``SACAgent.init``.

        Returns:
            The parent's agent state whose ``extra_state.discount_g`` is set
            to ``-log(1 - self.discount)``.
        """
        base_state = super().init(obs_space, action_space, key)

        # Inverse of the mapping discount = 1 - exp(-g).
        discount_g = -jnp.log(1 - jnp.float32(self.discount))
        extra_state = base_state.extra_state.replace(discount_g=discount_g)

        return base_state.replace(extra_state=extra_state)

    def critic_loss(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the critic loss using the discount stored in the agent state.

        Args:
            agent_state: Agent state providing ``params`` (critic, target
                critic, actor, ``log_alpha``), ``extra_state.discount_g`` and
                the observation preprocessor state.
            sample_batch: Batch providing ``obs``, ``actions``, ``rewards``
                and ``extras.env_extras`` (``ori_obs``, ``termination``).
            key: PRNG key used to sample the next actions.

        Returns:
            A ``PyTreeDict`` with ``critic_loss`` (squared error summed over
            the last axis of the Q outputs, then averaged) and ``q_value``
            (mean of the predicted Q values).
        """
        params = agent_state.params
        env_extras = sample_batch.extras.env_extras

        obs = sample_batch.obs
        next_obs = env_extras.ori_obs
        if self.normalize_obs:
            preprocessor_state = agent_state.obs_preprocessor_state
            obs = self.obs_preprocessor(obs, preprocessor_state)
            next_obs = self.obs_preprocessor(next_obs, preprocessor_state)

        # discount = 1 - exp(-g); the (1 - termination) factor removes the
        # bootstrap term wherever termination == 1.
        discount = 1 - jnp.exp(-agent_state.extra_state.discount_g)
        discounts = discount * (1 - env_extras.termination)

        alpha = jnp.exp(params.log_alpha)

        # [B, 2]
        qs = self.critic_network.apply(
            params.critic_params, obs, sample_batch.actions
        )

        # The actor output is split in two halves along the last axis and
        # handed to get_tanh_norm_dist as its two positional arguments.
        next_raw_actions = self.actor_network.apply(params.actor_params, next_obs)
        dist_args = jnp.split(next_raw_actions, 2, axis=-1)
        next_actions_dist = get_tanh_norm_dist(*dist_args)
        next_actions = next_actions_dist.sample(seed=key)
        next_actions_logp = next_actions_dist.log_prob(next_actions)

        # [B, 2]
        next_qs = self.critic_network.apply(
            params.target_critic_params, next_obs, next_actions
        )

        # Entropy-regularized value of the next state, using the smaller of
        # the target critic outputs.
        next_value = jnp.min(next_qs, axis=-1) - alpha * next_actions_logp
        qs_target = sample_batch.rewards + discounts * next_value
        # Same target for both critic outputs.
        qs_target = jnp.broadcast_to(qs_target[..., None], (*qs_target.shape, 2))

        per_sample_loss = jnp.sum(optax.squared_error(qs, qs_target), axis=-1)
        q_loss = jnp.mean(per_sample_loss)

        return PyTreeDict(critic_loss=q_loss, q_value=jnp.mean(qs))
90
91
[docs]
def make_mlp_sac_agent(
    action_space: Space,
    critic_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    init_alpha: float = 1.0,
    discount: float = 0.99,
    normalize_obs: bool = False,
) -> ParamSACAgent:
    """Build a ``ParamSACAgent`` with MLP actor and critic networks.

    Args:
        action_space: Action space of the environment; only ``Box`` is
            supported.
        critic_hidden_layer_sizes: Hidden layer sizes passed to
            ``make_q_network``.
        actor_hidden_layer_sizes: Hidden layer sizes passed to
            ``make_policy_network``.
        init_alpha: Forwarded to ``ParamSACAgent`` as ``init_alpha``.
        discount: Forwarded to ``ParamSACAgent`` as ``discount``.
        normalize_obs: When true, ``running_statistics.normalize`` is used as
            the observation preprocessor; otherwise no preprocessor is set.

    Returns:
        The constructed ``ParamSACAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is not a ``Box``.
    """
    if not isinstance(action_space, Box):
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    # The policy network emits two values per action dimension (mean+std).
    action_size = action_space.shape[0] * 2

    actor_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
    )

    critic_network = make_q_network(
        n_stack=2,
        hidden_layer_sizes=critic_hidden_layer_sizes,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return ParamSACAgent(
        critic_network=critic_network,
        actor_network=actor_network,
        obs_preprocessor=obs_preprocessor,
        init_alpha=init_alpha,
        discount=discount,
    )
127
128
[docs]
class ParamSACWorkflow(OffPolicyWorkflowTemplate):
    """Workflow for ParamSAC.

    Note: This workflow can only work with PBTParamSACWorkflow, since the
    replay buffer is initialized and managed by PBT externally.
    """

    @classmethod
    def name(cls) -> str:
        """Return the name of this workflow."""
        return "ParamSAC"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Create the workflow and all of its components from ``config``.

        Builds the training env (``AutoresetMode.NORMAL``, recording the
        original observation), the agent, an Adam optimizer (optionally
        preceded by global-norm gradient clipping), the replay buffer, and an
        evaluator running on a separate env with auto-reset disabled.

        Args:
            config: Workflow configuration.

        Returns:
            A new instance of ``cls``.
        """
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        agent = make_mlp_sac_agent(
            action_space=env.action_space,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            init_alpha=config.alpha,
            discount=config.discount,
            normalize_obs=config.normalize_obs,
        )

        # TODO: use different lr for critic and actor
        optimizer = optax.adam(config.optimizer.lr)
        grad_clip_norm = config.optimizer.grad_clip_norm
        if grad_clip_norm is not None and grad_clip_norm > 0:
            # Clip gradients by global norm before the Adam update.
            optimizer = optax.chain(
                optax.clip_by_global_norm(grad_clip_norm),
                optimizer,
            )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        return cls(
            env,
            agent,
            optimizer,
            evaluator,
            replay_buffer,
            config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree]:
        """Initialize the agent state and one optimizer state per network.

        Args:
            key: PRNG key used for agent initialization.

        Returns:
            ``(agent_state, opt_state)`` where ``opt_state`` has the entries
            ``actor`` and ``critic``, both created from the same optimizer.
        """
        agent_state = self.agent.init(self.env.obs_space, self.env.action_space, key)
        params = agent_state.params

        opt_state = PyTreeDict(
            actor=self.optimizer.init(params.actor_params),
            critic=self.optimizer.init(params.critic_params),
        )

        return agent_state, opt_state

    def setup(self, key: chex.PRNGKey) -> State:
        """Create the initial workflow state.

        The replay buffer state is left as ``None`` because it is initialized
        externally. ``hp_state`` starts with both loss weights set to 1.0.

        Args:
            key: PRNG key; split for the agent, the env and the state itself.

        Returns:
            The initial ``State``.
        """
        key, agent_key, env_key = jax.random.split(key, 3)

        agent_state, opt_state = self._setup_agent_and_optimizer(agent_key)
        workflow_metrics = self._setup_workflow_metrics()
        env_state = self.env.reset(env_key)

        return State(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
            replay_buffer_state=None,  # init externally
            hp_state=PyTreeDict(
                actor_loss_weight=jnp.float32(1.0),
                critic_loss_weight=jnp.float32(1.0),
            ),
        )

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: rollout, then a scan of updates.

        The sampled trajectory is not added to the replay buffer here; it is
        returned inside the train metrics so it can be stored externally, and
        ``state.replay_buffer_state`` is only read for sampling.

        Each of the ``num_updates_per_iter`` update rounds performs
        ``actor_update_interval - 1`` critic-only updates (when positive),
        followed by one critic update, one actor update and a soft update of
        the target critic.

        Args:
            state: Current workflow state.

        Returns:
            ``(train_metrics, new_state)``. ``new_state`` has an updated key,
            metrics, agent state, env state and optimizer state.
        """
        key, rollout_key, learn_key = jax.random.split(state.key, num=3)

        # the trajectory [T, B, ...]
        trajectory, env_state = rollout(
            env_fn=self.env.step,
            action_fn=self.agent.compute_actions,
            env_state=state.env_state,
            agent_state=state.agent_state,
            key=rollout_key,
            rollout_length=self.config.rollout_length,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Grab `dones` from the raw rollout for the episode counter; this is
        # done before clean_trajectory (presumably because it strips fields
        # -- confirm against offpolicy_utils).
        trajectory_dones = trajectory.dones
        trajectory = clean_trajectory(trajectory)
        trajectory = flatten_rollout_trajectory(trajectory)
        trajectory = tree_stop_gradient(trajectory)

        agent_state = state.agent_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    agent_state.obs_preprocessor_state,
                    trajectory.obs,
                    dp_axis_name=self.dp_axis_name,
                )
            )

        # Here replay_buffer_state is read-only,
        # we save the data externally instead
        replay_buffer_state = state.replay_buffer_state
        hp_state = state.hp_state

        # Number of critic-only updates preceding each joint update.
        num_extra_critic_updates = self.config.actor_update_interval - 1

        def critic_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.critic_loss(agent_state, sample_batch, key)

            loss = loss_dict.critic_loss * hp_state.critic_loss_weight
            return loss, loss_dict

        def actor_loss_fn(agent_state, sample_batch, key):
            loss_dict = self.agent.actor_loss(agent_state, sample_batch, key)

            loss = loss_dict.actor_loss * hp_state.actor_loss_weight
            return loss, loss_dict

        critic_update_fn = agent_gradient_update(
            critic_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, critic_params: agent_state.replace(
                params=agent_state.params.replace(critic_params=critic_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.critic_params,
        )

        actor_update_fn = agent_gradient_update(
            actor_loss_fn,
            self.optimizer,
            dp_axis_name=self.dp_axis_name,
            has_aux=True,
            attach_fn=lambda agent_state, actor_params: agent_state.replace(
                params=agent_state.params.replace(actor_params=actor_params)
            ),
            detach_fn=lambda agent_state: agent_state.params.actor_params,
        )

        def _sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state = carry

            critic_opt_state = opt_state.critic
            actor_opt_state = opt_state.actor

            key, critic_key, actor_key, rb_key = jax.random.split(key, num=4)

            if num_extra_critic_updates > 0:

                def _sample_and_update_critic_fn(carry, unused_t):
                    key, agent_state, critic_opt_state = carry

                    key, rb_key, critic_key = jax.random.split(key, num=3)
                    # it's safe to use read-only replay_buffer_state here.
                    sample_batch = self.replay_buffer.sample(
                        replay_buffer_state, rb_key
                    )

                    # The losses of these extra critic updates are not
                    # reported, so they are discarded.
                    _, agent_state, critic_opt_state = critic_update_fn(
                        critic_opt_state, agent_state, sample_batch, critic_key
                    )

                    return (key, agent_state, critic_opt_state), None

                key, critic_multiple_update_key = jax.random.split(key)

                (_, agent_state, critic_opt_state), _ = jax.lax.scan(
                    _sample_and_update_critic_fn,
                    (critic_multiple_update_key, agent_state, critic_opt_state),
                    (),
                    length=num_extra_critic_updates,
                )

            sample_batch = self.replay_buffer.sample(replay_buffer_state, rb_key)

            (critic_loss, critic_loss_dict), agent_state, critic_opt_state = (
                critic_update_fn(
                    critic_opt_state, agent_state, sample_batch, critic_key
                )
            )

            (actor_loss, actor_loss_dict), agent_state, actor_opt_state = (
                actor_update_fn(actor_opt_state, agent_state, sample_batch, actor_key)
            )

            opt_state = opt_state.replace(
                actor=actor_opt_state, critic=critic_opt_state
            )

            res = (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            )

            target_critic_params = soft_target_update(
                agent_state.params.target_critic_params,
                agent_state.params.critic_params,
                self.config.tau,
            )
            agent_state = agent_state.replace(
                params=agent_state.params.replace(
                    target_critic_params=target_critic_params
                )
            )

            return (key, agent_state, opt_state), res

        (
            (_, agent_state, opt_state),
            (
                critic_loss,
                actor_loss,
                critic_loss_dict,
                actor_loss_dict,
            ),
        ) = scan_and_mean(
            _sample_and_update_fn,
            (learn_key, agent_state, state.opt_state),
            (),
            length=self.config.num_updates_per_iter,
        )

        train_metrics = ParamSACTrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
            trajectory=trajectory,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # calculate the number of timestep
        sampled_timesteps = psum(
            jnp.uint32(self.config.rollout_length * self.config.num_envs),
            axis_name=self.dp_axis_name,
        )
        sampled_episodes = psum(
            trajectory_dones.sum().astype(jnp.uint32), axis_name=self.dp_axis_name
        )

        # iterations is the number of updates of the agent
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=state.metrics.iterations + 1,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        # replay_buffer_state is deliberately not written back: it is only
        # read in this step and is managed externally.
        return train_metrics, state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            env_state=env_state,
            opt_state=opt_state,
        )