1import logging
2import math
3from omegaconf import DictConfig
4
5import chex
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9import optax
10
11from evorl.replay_buffers import ReplayBuffer
12from evorl.metrics import MetricBase
13from evorl.types import PyTreeDict, State
14from evorl.utils.jax_utils import tree_get
15from evorl.evaluators import Evaluator, EpisodeCollector
16from evorl.agent import AgentState
17from evorl.envs import create_env, AutoresetMode
18from evorl.recorders import get_1d_array_statistics, add_prefix, get_1d_array
19from evorl.ec.optimizers import ECState, VanillaESMod, ExponentialScheduleSpec
20from evorl.algorithms.td3 import make_mlp_td3_agent
21from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
22
23from ..erl_workflow import ERLTrainMetric
24from .erl_td3_workflow import ERLTD3WorkflowTemplate, erl_replace_td3_actor_params
25
26logger = logging.getLogger(__name__)
27
28
[docs]
class EvaluateMetric(MetricBase):
    """Evaluation metrics produced by ``ERLESWorkflow.evaluate``.

    The ``rl_*`` fields come from evaluating the RL agents, while the
    ``pop_center_*`` fields come from evaluating an agent whose actor
    parameters are the EC optimizer's mean (``ec_opt_state.mean``).
    """

    # Episode returns of the RL agents, averaged over the last axis of the
    # evaluator output (one entry per RL agent, since evaluation is vmapped
    # over the RL agents).
    rl_episode_returns: chex.Array
    # Episode lengths of the RL agents, averaged the same way.
    rl_episode_lengths: chex.Array
    # Mean episode return of the population-center (EC mean) actor.
    pop_center_episode_returns: chex.Array
    # Mean episode length of the population-center (EC mean) actor.
    pop_center_episode_lengths: chex.Array
33 pop_center_episode_lengths: chex.Array
34
35
[docs]
class ERLESWorkflow(ERLTD3WorkflowTemplate):
    """ERL w/ ES.

    Configs:

    - EC: n actors
    - RL: k actors + k critics
    - Shared replay buffer

    Each training iteration rolls out the ES population and the RL agents,
    updates the RL agents with TD3, injects the updated RL actors into the ES
    population as external individuals, and finally updates the ES optimizer.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "ERL-ES"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the workflow and all of its components from ``config``.

        Args:
            config: Workflow configuration. The fields read here are the
                ones accessed in the body below (``env``, ``agent_network``,
                ``optimizer``, ``pop_size``, ``num_rl_agents``, ...).

        Returns:
            A fully constructed workflow instance of ``cls``.
        """
        # env for rl&ec rollout
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is only enabled for a strictly positive norm.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        ec_optimizer = VanillaESMod(
            pop_size=config.pop_size,
            external_size=config.num_rl_agents,
            num_elites=config.num_elites,
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
            mix_strategy=config.mix_strategy,
        )

        # The population's fitness rollout and the RL rollout can each be run
        # with or without the agent's exploration behaviour.
        ec_action_fn = (
            agent.compute_actions
            if config.fitness_with_exploration
            else agent.evaluate_actions
        )
        ec_collector = EpisodeCollector(
            env=env,
            action_fn=ec_action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        rl_action_fn = (
            agent.compute_actions
            if config.rl_exploration
            else agent.evaluate_actions
        )
        rl_collector = EpisodeCollector(
            env=env,
            action_fn=rl_action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Used as vmap ``in_axes``: parameters are mapped over axis 0, while
        # the observation preprocessor state is shared (not mapped).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        return cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            ec_collector=ec_collector,
            rl_collector=rl_collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Initialize the RL agents, their optimizer states and the EC state.

        Args:
            key: PRNG key used for all initializations.

        Returns:
            A tuple ``(agent_state, opt_state, ec_opt_state)`` where
            ``agent_state`` holds ``num_rl_agents`` stacked RL agents,
            ``opt_state`` holds the actor and critic optimizer states, and
            ``ec_opt_state`` is the EC optimizer state initialized from a
            freshly initialized actor network.
        """
        agent_key, pop_agent_key, ec_key = jax.random.split(key, 3)

        # agent for RL
        agent_state = jax.vmap(self.agent.init, in_axes=(None, None, 0))(
            self.env.obs_space,
            self.env.action_space,
            jax.random.split(agent_key, self.config.num_rl_agents),
        )

        # all agents will share the same obs_preprocessor_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=tree_get(agent_state.obs_preprocessor_state, 0)
            )

        # NOTE(review): this reuses ``key`` after it has been split above.
        # The sampled observation is presumably only needed for its
        # shape/dtype during network init, so this looks harmless; it is left
        # unchanged to keep the RNG streams (and seeds) reproducible.
        dummy_obs = self.env.obs_space.sample(key)
        init_actor_params = self.agent.actor_network.init(
            pop_agent_key, jtu.tree_map(lambda x: x[None, ...], dummy_obs)
        )

        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )

        return agent_state, opt_state, ec_opt_state

    def _rl_injection(
        self,
        ec_opt_state: ECState,
        agent_state: AgentState,
        ec_fitnesses: chex.Array,
        rl_fitnesses: chex.Array,
    ) -> tuple[chex.Array, ECState]:
        """Append the RL actors to the ES population as external individuals.

        The RL actor parameters are expressed as offsets from the ES mean and
        concatenated (along axis 0) after the ES noise; the fitnesses are
        concatenated in the same order.

        Args:
            ec_opt_state: Current EC optimizer state (provides ``mean`` and
                ``noise``).
            agent_state: RL agent state whose ``params.actor_params`` are
                injected.
            ec_fitnesses: Fitness of each ES individual.
            rl_fitnesses: Fitness of each RL agent.

        Returns:
            A tuple ``(fitnesses, ec_opt_state)`` with the concatenated
            fitnesses and the EC state holding the concatenated noise.
        """
        # Offset of every RL actor from the ES mean.
        # assumes the leading axis of actor_params indexes the RL agents and
        # broadcasts against the (unbatched) mean -- TODO confirm
        rl_noise = jtu.tree_map(
            lambda x, m: x - m,
            agent_state.params.actor_params,
            ec_opt_state.mean,
        )

        concat_noise = jtu.tree_map(
            lambda n1, n2: jnp.concatenate([n1, n2], axis=0),
            ec_opt_state.noise,
            rl_noise,
        )

        ec_opt_state = ec_opt_state.replace(noise=concat_noise)

        fitnesses = jnp.concatenate([ec_fitnesses, rl_fitnesses], axis=0)

        return fitnesses, ec_opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one training iteration.

        The iteration consists of: EC rollout, RL rollout, RL (TD3) update,
        injection of the updated RL actors into the ES population, and the
        ES update.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, state)`` with the metrics of this
            iteration and the updated workflow state.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, ec_rollout_key, rl_rollout_key, learn_key = jax.random.split(
            state.key, num=4
        )

        # ======== EC & RL rollout ========
        # the trajectory [#pop, T, B, ...]
        # metrics: [#pop, B]
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)

        pop_agent_state = erl_replace_td3_actor_params(agent_state, pop_actor_params)
        # The trajectories themselves are not needed here; only the metrics
        # and the updated replay buffer state are used.
        ec_eval_metrics, _, replay_buffer_state = self._ec_rollout(
            pop_agent_state, replay_buffer_state, ec_rollout_key
        )

        ec_sampled_timesteps = ec_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        ec_sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        rl_eval_metrics, _, replay_buffer_state = self._rl_rollout(
            agent_state, replay_buffer_state, rl_rollout_key
        )

        rl_sampled_timesteps = rl_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        rl_sampled_episodes = jnp.uint32(
            self.config.num_rl_agents * self.config.rollout_episodes
        )

        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=ec_eval_metrics.episode_returns.mean(-1),
        )

        # ======== RL update ========
        td3_metrics, agent_state, opt_state = self._rl_update(
            agent_state, opt_state, replay_buffer_state, learn_key
        )

        # get average loss
        # presumably _rl_update reports losses summed over the RL agents;
        # verify against its implementation
        td3_metrics = td3_metrics.replace(
            actor_loss=td3_metrics.actor_loss / self.config.num_rl_agents,
            critic_loss=td3_metrics.critic_loss / self.config.num_rl_agents,
        )

        train_metrics = train_metrics.replace(
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
        )

        # ======== EC update ========
        # inject RL into EC
        # Note: the injected actors are the ones *after* the RL update, while
        # their fitnesses come from the rollout performed before the update.
        ec_fitnesses = ec_eval_metrics.episode_returns.mean(axis=-1)
        rl_fitnesses = rl_eval_metrics.episode_returns.mean(axis=-1)
        fitnesses, ec_opt_state = self._rl_injection(
            ec_opt_state, agent_state, ec_fitnesses, rl_fitnesses
        )

        ec_metrics, ec_opt_state = self.ec_optimizer.tell_external(
            ec_opt_state, fitnesses
        )

        train_metrics = train_metrics.replace(
            ec_info=ec_metrics, rb_size=replay_buffer_state.buffer_size
        )

        # calculate the number of timestep
        sampled_timesteps = ec_sampled_timesteps + rl_sampled_timesteps
        sampled_episodes = ec_sampled_episodes + rl_sampled_episodes
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            rl_sampled_timesteps=state.metrics.rl_sampled_timesteps
            + rl_sampled_timesteps,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the RL agents and the population-center actor.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(eval_metrics, state)`` where ``eval_metrics`` is an
            ``EvaluateMetric`` and ``state`` only differs from the input by
            its advanced PRNG key.
        """
        key, rl_eval_key, ec_eval_key = jax.random.split(state.key, num=3)

        # Evaluate every RL agent with its own PRNG key.
        rl_eval_metrics = jax.vmap(
            self.evaluator.evaluate, in_axes=(self.agent_state_vmap_axes, 0, None)
        )(
            state.agent_state,
            jax.random.split(rl_eval_key, num=self.config.num_rl_agents),
            self.config.eval_episodes,
        )

        # Evaluate a single agent whose actor is the ES mean.
        pop_mean_actor_params = state.ec_opt_state.mean

        pop_mean_agent_state = erl_replace_td3_actor_params(
            state.agent_state, pop_mean_actor_params
        )

        ec_eval_metrics = self.evaluator.evaluate(
            pop_mean_agent_state, ec_eval_key, num_episodes=self.config.eval_episodes
        )

        eval_metrics = EvaluateMetric(
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            pop_center_episode_returns=ec_eval_metrics.episode_returns.mean(),
            pop_center_episode_lengths=ec_eval_metrics.episode_lengths.mean(),
        )

        state = state.replace(key=key)

        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Run the training loop until ``config.total_episodes`` is reached.

        Every iteration calls ``step``, records the workflow and training
        metrics, periodically evaluates, and saves a checkpoint.

        Args:
            state: Workflow state to start (or resume) training from.

        Returns:
            The workflow state after the last iteration. If the episode
            budget is already exhausted, ``state`` is returned unchanged.
        """
        sampled_episodes_per_iter = (
            self.config.episodes_for_fitness * self.config.pop_size
            + self.config.rollout_episodes * self.config.num_rl_agents
        )

        # Convert the JAX scalar counters to Python ints once. The episode
        # counter is accumulated from ``jnp.uint32`` increments in ``step``,
        # so doing the subtraction on the JAX side may wrap around when the
        # restored state already exceeds ``total_episodes``; clamping at zero
        # makes that case a no-op instead of a near-endless loop.
        start_iteration = int(state.metrics.iterations)
        remaining_episodes = max(
            int(self.config.total_episodes) - int(state.metrics.sampled_episodes),
            0,
        )
        num_iters = math.ceil(remaining_episodes / sampled_episodes_per_iter)

        final_iteration = start_iteration + num_iters
        for i in range(start_iteration, final_iteration):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            if self.config.num_rl_agents > 1:
                # Several RL agents: report statistics over the agents.
                train_metrics_dict["rl_episode_lengths"] = get_1d_array_statistics(
                    train_metrics_dict["rl_episode_lengths"], histogram=True
                )
                train_metrics_dict["rl_episode_returns"] = get_1d_array_statistics(
                    train_metrics_dict["rl_episode_returns"], histogram=True
                )
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )
            else:
                # Single RL agent: drop the leading agent axis.
                train_metrics_dict["rl_episode_lengths"] = train_metrics_dict[
                    "rl_episode_lengths"
                ].squeeze(0)
                train_metrics_dict["rl_episode_returns"] = train_metrics_dict[
                    "rl_episode_returns"
                ].squeeze(0)

            self.recorder.write(train_metrics_dict, iters)

            # Always evaluate on the last iteration, even off-interval.
            if iters % self.config.eval_interval == 0 or iters == final_iteration:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = eval_metrics.to_local_dict()
                if self.config.num_rl_agents > 1:
                    eval_metrics_dict = jtu.tree_map(get_1d_array, eval_metrics_dict)
                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            # Force a checkpoint on the final iteration.
            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=iters == final_iteration,
            )

        return state