1import math
2import numpy as np
3from typing_extensions import Self # pytype: disable=not-supported-yet]
4from omegaconf import DictConfig
5
6import chex
7import jax
8import jax.numpy as jnp
9import jax.tree_util as jtu
10import optax
11
12from evorl.replay_buffers import ReplayBuffer
13from evorl.metrics import MetricBase
14from evorl.types import PyTreeDict, State, Params
15from evorl.utils.jax_utils import tree_get, tree_set
16from evorl.evaluators import Evaluator, EpisodeCollector
17from evorl.agent import AgentState
18from evorl.envs import create_env, AutoresetMode
19from evorl.recorders import get_1d_array_statistics, add_prefix
20from evorl.ec.optimizers import SepCEM, ECState, ExponentialScheduleSpec
21from evorl.algorithms.td3 import make_mlp_td3_agent, TD3NetworkParams
22from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
23
24from ..cemrl_workflow import CEMRLTrainMetric
25from .cemrl_td3_workflow import (
26 CEMRLTD3WorkflowTemplate,
27 cemrl_replace_td3_actor_params,
28 create_dummy_td3_trainmetric,
29)
30
31
[docs]
class EvaluateMetric(MetricBase):
    """Evaluation metrics for the actor at the center of the CEM population.

    `CEMRLWorkflow.evaluate` fills both fields by evaluating the actor whose
    parameters are the mean of the CEM search distribution
    (`ec_opt_state.mean`) and averaging over the evaluation episodes.
    """

    # Episode return of the population-center actor, averaged over eval episodes.
    pop_center_episode_returns: chex.Array
    # Episode length of the population-center actor, averaged over eval episodes.
    pop_center_episode_lengths: chex.Array
35
36
[docs]
class CEMRLWorkflow(CEMRLTD3WorkflowTemplate):
    """1 critic + n actors + 1 replay buffer.

    We use shard_map to split and parallelize the population.

    Each training iteration samples a population of actors from a SepCEM
    distribution, lets a random subset of them learn with TD3 against the
    shared critic, rolls out the whole population to obtain fitnesses, and
    finally updates the CEM distribution with those fitnesses.
    """

    @classmethod
    def name(cls) -> str:
        """Return the identifier string of this workflow."""
        return "CEMRL"

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow and all of its components from `config`.

        Builds the training env, the TD3 agent, the gradient optimizer, the
        SepCEM optimizer, the episode collector, the replay buffer and the
        evaluator, then passes them to the class constructor.

        Args:
            config: Workflow configuration.

        Returns:
            A new workflow instance.

        Raises:
            AssertionError: If neither `warmup_iters` nor `random_timesteps`
                is positive, i.e. nothing would pre-fill the replay buffer.
        """
        assert config.warmup_iters > 0 or config.random_timesteps > 0, (
            "Either warmup_iters or random_timesteps should be positive to pre-fill some data in the replay buffer"
        )

        # env for one actor
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is optional: only chain it in when a positive norm
        # is configured.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        ec_optimizer = SepCEM(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            cov_eps_schedule=ExponentialScheduleSpec(**config.cov_eps),
            weighted_update=config.weighted_update,
            rank_weight_shift=config.rank_weight_shift,
            mirror_sampling=config.mirror_sampling,
        )

        # Fitness rollouts use either the agent's `compute_actions` or its
        # `evaluate_actions`, selected by `fitness_with_exploration`.
        if config.fitness_with_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Axis spec over the agent state: actor params (and their targets) use
        # axis 0, everything else is None. Presumably consumed as vmap in_axes
        # by the workflow template -- TODO confirm.
        agent_state_vmap_axes = AgentState(
            params=TD3NetworkParams(
                critic_params=None,
                actor_params=0,
                target_critic_params=None,
                target_actor_params=0,
            ),
            obs_preprocessor_state=None,
        )

        workflow = cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            collector=collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

        return workflow

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Create the initial agent state, optimizer state and CEM state.

        Args:
            key: PRNG key, split between agent init and CEM init.

        Returns:
            A tuple `(agent_state, opt_state, ec_opt_state)`. `opt_state` only
            holds the critic's optimizer state; its `actor` entry is None.
        """
        agent_key, ec_key = jax.random.split(key)

        # one actor + one critic
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        # The freshly initialised actor seeds the CEM optimizer.
        init_actor_params = agent_state.params.actor_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        # Actors live in the CEM state from now on; the helper is called with
        # pop_actor_params=None (presumably clearing the actor slots).
        agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params=None)

        opt_state = PyTreeDict(
            # Note: we create and drop the actors' opt_state at every step
            critic=self.optimizer.init(agent_state.params.critic_params),
            actor=None,
        )

        return agent_state, opt_state, ec_opt_state

    def _rl_injection(self, ec_opt_state: ECState, pop: Params) -> ECState:
        """Return `ec_opt_state` with its population replaced by `pop`."""
        return ec_opt_state.replace(pop=pop)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration: CEM sample, RL update, rollout, CEM update.

        Args:
            state: Current workflow state.

        Returns:
            A tuple `(train_metrics, new_state)`.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, perm_key, rollout_key, learn_key = jax.random.split(state.key, num=4)

        # ======= CEM Sample ========
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)

        # ======== RL update ========
        # Draw the learning actors at random without replacement, rather than
        # always taking the first `num_learning_offspring` individuals.
        learning_actor_indices = jax.random.choice(
            perm_key,
            self.config.pop_size,
            (self.config.num_learning_offspring,),
            replace=False,
        )

        # NOTE: this closure shares its name with `self._rl_update`, which it
        # calls; the method is the actual TD3 update.
        def _rl_update(agent_state, opt_state, pop_actor_params):
            learning_actor_params = tree_get(pop_actor_params, learning_actor_indices)
            learning_agent_state = cemrl_replace_td3_actor_params(
                agent_state, learning_actor_params
            )

            # reset actors' opt_state
            learning_opt_state = opt_state.replace(
                actor=self.optimizer.init(learning_actor_params),
            )

            td3_metrics, learning_agent_state, learning_opt_state = self._rl_update(
                learning_agent_state,
                learning_opt_state,
                replay_buffer_state,
                learn_key,
            )

            # Write the trained actors back into their population slots.
            pop_actor_params = tree_set(
                pop_actor_params,
                learning_agent_state.params.actor_params,
                learning_actor_indices,
                unique_indices=True,
            )
            # drop the actors and their opt_state
            agent_state = cemrl_replace_td3_actor_params(
                learning_agent_state, pop_actor_params=None
            )
            opt_state = learning_opt_state.replace(actor=None)
            return td3_metrics, pop_actor_params, agent_state, opt_state

        def _dummy_rl_update(agent_state, opt_state, pop_actor_params):
            # Warm-up branch: pass everything through with placeholder metrics
            # so both `cond` branches return the same structure.
            return (
                create_dummy_td3_trainmetric(self.config.num_learning_offspring),
                pop_actor_params,
                agent_state,
                opt_state,
            )

        # No RL update during the first `warmup_iters` iterations.
        td3_metrics, pop_actor_params, agent_state, opt_state = jax.lax.cond(
            iterations > self.config.warmup_iters,
            _rl_update,
            _dummy_rl_update,
            agent_state,
            opt_state,
            pop_actor_params,
        )

        # ======== CEM update ========
        pop_agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params)

        # the trajectory [T, #pop*B, ...]
        # metrics: [#pop, B]
        eval_metrics, trajectory, replay_buffer_state = self._rollout(
            pop_agent_state, replay_buffer_state, rollout_key
        )

        # Fitness of an individual = its mean return over the last axis.
        fitnesses = eval_metrics.episode_returns.mean(axis=-1)

        # Replace the sampled population with the RL-updated one before `tell`.
        ec_opt_state = self._rl_injection(ec_opt_state, pop_actor_params)
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        # adding debug info for CEM
        ec_info = PyTreeDict(ec_metrics)
        ec_info.cov_eps = ec_opt_state.cov_eps

        def _calc_elites_stats(ec_info):
            # Count how many RL-trained actors are among the top `num_elites`.
            elites_indices = jax.lax.top_k(fitnesses, self.config.num_elites)[1]
            elites_from_rl = jnp.isin(learning_actor_indices, elites_indices).astype(
                jnp.int32
            )
            ec_info.elites_from_rl = elites_from_rl.sum()
            ec_info.elites_from_rl_ratio = elites_from_rl.mean()
            return ec_info

        def _dummy_calc_elites_stats(ec_info):
            # Warm-up branch: no actor was RL-trained, so report zeros.
            ec_info.elites_from_rl = jnp.zeros((), dtype=jnp.int32)
            ec_info.elites_from_rl_ratio = jnp.zeros(())
            return ec_info

        ec_info = jax.lax.cond(
            iterations > self.config.warmup_iters,
            _calc_elites_stats,
            _dummy_calc_elites_stats,
            ec_info,
        )

        train_metrics = CEMRLTrainMetric(
            rb_size=replay_buffer_state.buffer_size,
            pop_episode_lengths=eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
            ec_info=ec_info,
        )

        # calculate the number of timestep
        sampled_timesteps = eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the actor at the mean of the CEM distribution.

        Args:
            state: Current workflow state.

        Returns:
            A tuple `(eval_metrics, new_state)`; only the PRNG key of the
            state is changed.
        """
        pop_mean_actor_params = state.ec_opt_state.mean

        pop_mean_agent_state = cemrl_replace_td3_actor_params(
            state.agent_state, pop_mean_actor_params
        )

        key, eval_key = jax.random.split(state.key, num=2)

        # [#episodes]
        raw_eval_metrics = self.evaluator.evaluate(
            pop_mean_agent_state, eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            pop_center_episode_returns=raw_eval_metrics.episode_returns.mean(),
            pop_center_episode_lengths=raw_eval_metrics.episode_lengths.mean(),
        )

        state = state.replace(key=key)

        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Train until `config.total_episodes` episodes have been sampled.

        Runs `step` repeatedly, writes workflow/training/CEM-std metrics to the
        recorder every iteration, and evaluates plus checkpoints every
        `eval_interval` iterations and on the last iteration.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The state after the last iteration. If the episode budget is
            already used up, `state` is returned unchanged.
        """
        episodes_per_iter = self.config.episodes_for_fitness * self.config.pop_size

        # Convert the counters to host ints once. `step` accumulates
        # `sampled_episodes` from uint32 values, so doing
        # `total_episodes - sampled_episodes` on the device scalar would wrap
        # around to a huge positive number when a resumed run has already
        # reached its budget, and the division would run in device float
        # precision.
        sampled_episodes = int(state.metrics.sampled_episodes)
        start_iteration = int(state.metrics.iterations)

        remaining_episodes = max(self.config.total_episodes - sampled_episodes, 0)
        num_iters = math.ceil(remaining_episodes / episodes_per_iter)
        final_iteration = start_iteration + num_iters

        # `iters` is the 1-based index of the iteration being executed.
        for iters in range(start_iteration + 1, final_iteration + 1):
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            # Reduce per-individual arrays to summary statistics (+ histogram).
            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            if train_metrics_dict["rl_metrics"] is not None:
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )

            self.recorder.write(train_metrics_dict, iters)

            std_statistics = get_std_statistics(state.ec_opt_state.variance["params"])
            self.recorder.write({"ec/std": std_statistics}, iters)

            is_last_iteration = iters == final_iteration
            if iters % self.config.eval_interval == 0 or is_last_iteration:
                eval_metrics, state = self.evaluate(state)

                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            # The replay buffer is excluded from the checkpoint unless
            # `save_replay_buffer` is enabled.
            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            self.checkpoint_manager.save(
                iters,
                saved_state,
                is_last_iteration,
            )

        return state
394
395
[docs]
def get_std_statistics(variance):
    """Summarise the standard deviation of every leaf in a variance pytree.

    Args:
        variance: Pytree whose leaves are arrays of variances.

    Returns:
        A pytree with the same structure in which each leaf is replaced by a
        dict with the `min`, `max` and `mean` of the leaf's element-wise
        square root, as plain Python numbers.
    """

    def _summarize(var_leaf):
        # variance -> standard deviation
        std = np.sqrt(var_leaf)
        smallest = np.min(std)
        largest = np.max(std)
        average = np.mean(std)
        # `.tolist()` turns the NumPy scalars into built-in Python numbers.
        return {
            "min": smallest.tolist(),
            "max": largest.tolist(),
            "mean": average.tolist(),
        }

    return jtu.tree_map(_summarize, variance)
405 return jtu.tree_map(_get_stats, variance)