1import logging
2import math
3from omegaconf import DictConfig
4
5import chex
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9import optax
10
11from evorl.replay_buffers import ReplayBuffer
12from evorl.metrics import MetricBase
13from evorl.types import PyTreeDict, State
14from evorl.evaluators import Evaluator, EpisodeCollector
15from evorl.agent import AgentState
16from evorl.envs import create_env, AutoresetMode
17from evorl.recorders import get_1d_array_statistics, add_prefix, get_1d_array
18from evorl.ec.optimizers import ERLGAMod, ECState
19from evorl.algorithms.td3 import make_mlp_td3_agent
20from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
21
22from ..erl_workflow import ERLTrainMetric
23from .erl_td3_workflow import ERLTD3WorkflowTemplate, erl_replace_td3_actor_params
24
25logger = logging.getLogger(__name__)
26
27
[docs]
class EvaluateMetric(MetricBase):
    """Evaluation metrics of the RL agents.

    ``ERLGAWorkflow.evaluate`` fills both fields with values averaged over
    the evaluation episodes of each RL agent.
    """

    # Mean episode return, one entry per RL agent.
    rl_episode_returns: chex.Array
    # Mean episode length, one entry per RL agent.
    rl_episode_lengths: chex.Array
31
32
[docs]
class ERLGAWorkflow(ERLTD3WorkflowTemplate):
    """ERL w/ GA: a GA-evolved actor population trained alongside TD3 agents.

    Config:

    - EC: n actors (``pop_size``), evolved by ``ERLGAMod``
    - RL: k actors + k critics (``num_rl_agents`` TD3 agents)
    - Shared replay buffer

    Every ``rl_injection_interval`` iterations the RL actors are handed to the
    EC optimizer as its external population (see :meth:`step`).
    """

    @classmethod
    def name(cls) -> str:
        """Return the display name of this workflow."""
        return "ERL-GA"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the workflow and all of its components from ``config``.

        Creates the rollout environment, the TD3 agent, the gradient
        optimizer, the GA optimizer, the EC/RL episode collectors, the replay
        buffer and the evaluator, then passes them to the class constructor.

        Args:
            config: Workflow configuration.

        Returns:
            The constructed workflow instance.

        Raises:
            AssertionError: If ``config.num_elites < config.num_rl_agents``.
        """
        # The GA below is created with ``external_size=num_rl_agents``.
        # NOTE(review): presumably the injected RL actors must fit among the
        # elites - confirm against ERLGAMod.
        assert config.num_elites >= config.num_rl_agents

        # env for rl&ec rollout
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Adam, preceded by global-norm gradient clipping only when a
        # positive clip norm is configured.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        ec_optimizer = ERLGAMod(
            pop_size=config.pop_size,
            num_elites=config.num_elites,
            external_size=config.num_rl_agents,
            weight_max_magnitude=config.weight_max_magnitude,
            mut_strength=config.mut_strength,
            num_mutation_frac=config.num_mutation_frac,
            super_mut_strength=config.super_mut_strength,
            super_mut_prob=config.super_mut_prob,
            reset_prob=config.reset_prob,
            vec_relative_prob=config.vec_relative_prob,
            enable_crossover=config.enable_crossover,
            num_crossover_frac=config.num_crossover_frac,
        )

        # Action function used when rolling out the EC population.
        if config.fitness_with_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        ec_collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        # Action function used when rolling out the RL agents.
        if config.rl_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        rl_collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # Separate environment for the evaluator (used by ``evaluate``).
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # vmap axes for a stacked agent state: params are batched on axis 0,
        # the obs preprocessor state is not batched (shared).
        agent_state_vmap_axes = AgentState(
            params=0,
            obs_preprocessor_state=None,
        )

        workflow = cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            ec_collector=ec_collector,
            rl_collector=rl_collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

        return workflow

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Initialise the RL agents, their optimizer state and the EC state.

        Args:
            key: PRNG key; split for the RL agents, the EC population and the
                EC optimizer.

        Returns:
            A tuple ``(agent_state, opt_state, ec_opt_state)`` where
            ``agent_state`` stacks ``num_rl_agents`` agents, ``opt_state`` is
            a ``PyTreeDict`` with ``actor`` and ``critic`` optimizer states,
            and ``ec_opt_state`` is the EC optimizer state initialised from
            ``pop_size`` actor parameter sets.
        """
        agent_key, pop_agent_key, ec_key = jax.random.split(key, 3)

        # agent for RL
        agent_state = jax.vmap(self.agent.init, in_axes=(None, None, 0))(
            self.env.obs_space,
            self.env.action_space,
            jax.random.split(agent_key, self.config.num_rl_agents),
        )

        # all agents will share the same obs_preprocessor_state
        if agent_state.obs_preprocessor_state is not None:
            agent_state = agent_state.replace(
                obs_preprocessor_state=jtu.tree_map(
                    lambda x: x[0], agent_state.obs_preprocessor_state
                )
            )

        # NOTE(review): ``key`` was already consumed by the split above and is
        # reused here. The sample is presumably needed only as a shape
        # template for network init - confirm. Left unchanged so existing
        # seeds keep producing the same initial parameters.
        dummy_obs = self.env.obs_space.sample(key)
        pop_actor_params = jax.vmap(self.agent.actor_network.init, in_axes=(0, None))(
            jax.random.split(pop_agent_key, self.config.pop_size),
            # add a leading axis of size 1 to the dummy observation
            jtu.tree_map(lambda x: x[None, ...], dummy_obs),
        )

        ec_opt_state = self.ec_optimizer.init(pop_actor_params, ec_key)

        # Optimizer states are initialised on the stacked RL params.
        opt_state = PyTreeDict(
            actor=self.optimizer.init(agent_state.params.actor_params),
            critic=self.optimizer.init(agent_state.params.critic_params),
        )

        return agent_state, opt_state, ec_opt_state

    def _rl_injection(self, ec_opt_state: ECState, agent_state: AgentState) -> ECState:
        """Store the RL actors' parameters as the EC external population.

        Args:
            ec_opt_state: Current EC optimizer state.
            agent_state: RL agent state whose actor params carry a leading
                axis of size ``num_rl_agents`` (asserted).

        Returns:
            ``ec_opt_state`` with ``external_pop`` replaced by the RL actor
            parameters.
        """
        rl_actor_params = agent_state.params.actor_params
        chex.assert_tree_shape_prefix(rl_actor_params, (self.config.num_rl_agents,))

        ec_opt_state = ec_opt_state.replace(external_pop=rl_actor_params)

        return ec_opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """The basic step function for the workflow to update agent.

        One iteration consists of an EC rollout of the population, an RL
        rollout and TD3 update of the RL agents, and an EC update. The EC
        update uses ``tell_external`` with the RL actors injected when the
        new iteration count is a multiple of ``rl_injection_interval``, and
        plain ``tell`` otherwise.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, state)`` with the ``ERLTrainMetric`` of
            this iteration and the updated workflow state.
        """
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, ec_rollout_key, rl_rollout_key, learn_key = jax.random.split(
            state.key, num=4
        )

        # ======== EC rollout ========
        # the trajectory [#pop, T, B, ...]
        # metrics: [#pop, B]
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)
        pop_agent_state = erl_replace_td3_actor_params(agent_state, pop_actor_params)
        ec_eval_metrics, _, replay_buffer_state = self._ec_rollout(
            pop_agent_state, replay_buffer_state, ec_rollout_key
        )

        ec_sampled_timesteps = ec_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        ec_sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        # ======== RL update ========

        rl_eval_metrics, _, replay_buffer_state = self._rl_rollout(
            agent_state, replay_buffer_state, rl_rollout_key
        )

        rl_sampled_timesteps = rl_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        rl_sampled_episodes = jnp.uint32(
            self.config.num_rl_agents * self.config.rollout_episodes
        )

        td3_metrics, agent_state, opt_state = self._rl_update(
            agent_state, opt_state, replay_buffer_state, learn_key
        )

        # get average loss
        td3_metrics = td3_metrics.replace(
            actor_loss=td3_metrics.actor_loss / self.config.num_rl_agents,
            critic_loss=td3_metrics.critic_loss / self.config.num_rl_agents,
        )

        # ====== EC update ======
        # Fitness of an individual = mean return over its episodes.
        fitnesses = ec_eval_metrics.episode_returns.mean(axis=-1)

        def _tell_with_rl_injection(ec_opt_state, fitnesses):
            # Uses the freshly updated ``agent_state`` from the RL update.
            ec_opt_state = self._rl_injection(ec_opt_state, agent_state)
            ec_metrics, ec_opt_state = self.ec_optimizer.tell_external(
                ec_opt_state, fitnesses
            )
            return ec_metrics, ec_opt_state

        ec_metrics, ec_opt_state = jax.lax.cond(
            iterations % self.config.rl_injection_interval == 0,
            _tell_with_rl_injection,
            self.ec_optimizer.tell,
            ec_opt_state,
            fitnesses,
        )

        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=ec_eval_metrics.episode_returns.mean(-1),
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
            ec_info=ec_metrics,
            rb_size=replay_buffer_state.buffer_size,
        )

        # calculate the number of timestep
        sampled_timesteps = ec_sampled_timesteps + rl_sampled_timesteps
        sampled_episodes = ec_sampled_episodes + rl_sampled_episodes
        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            rl_sampled_timesteps=state.metrics.rl_sampled_timesteps
            + rl_sampled_timesteps,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate every RL agent and return per-agent mean metrics.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(eval_metrics, state)`` where ``eval_metrics`` is an
            ``EvaluateMetric`` holding the mean episode return and length of
            each RL agent, and ``state`` has its PRNG key advanced.
        """
        key, eval_key = jax.random.split(state.key, num=2)

        # [num_rl_agents, #episodes]
        raw_eval_metrics = jax.vmap(
            self.evaluator.evaluate, in_axes=(self.agent_state_vmap_axes, 0, None)
        )(
            state.agent_state,
            jax.random.split(eval_key, self.config.num_rl_agents),
            self.config.eval_episodes,
        )

        eval_metrics = EvaluateMetric(
            rl_episode_returns=raw_eval_metrics.episode_returns.mean(-1),
            rl_episode_lengths=raw_eval_metrics.episode_lengths.mean(-1),
        )

        state = state.replace(key=key)
        return eval_metrics, state

    def learn(self, state: State) -> State:
        """Run training iterations until ``total_episodes`` have been sampled.

        Each iteration calls :meth:`step`, writes workflow and training
        metrics to the recorder, evaluates every ``eval_interval`` iterations
        and on the final iteration, and hands the state to the checkpoint
        manager.

        Args:
            state: Workflow state to start (or resume) from.

        Returns:
            The state after the last iteration, or ``state`` unchanged when
            the episode budget is already exhausted.
        """
        sampled_episodes_per_iter = (
            self.config.episodes_for_fitness * self.config.pop_size
            + self.config.rollout_episodes * self.config.num_rl_agents
        )

        # Work with plain Python ints. ``step`` accumulates uint32 increments
        # into these counters, so they are presumably unsigned arrays;
        # subtracting in array arithmetic would wrap around (instead of going
        # negative) once more episodes than ``total_episodes`` were sampled.
        start_iteration = int(state.metrics.iterations)
        remaining_episodes = max(
            int(self.config.total_episodes) - int(state.metrics.sampled_episodes),
            0,
        )
        num_iters = math.ceil(remaining_episodes / sampled_episodes_per_iter)

        final_iteration = start_iteration + num_iters
        for i in range(start_iteration, final_iteration):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            if self.config.num_rl_agents > 1:
                # Several RL agents: record statistics over the agent axis.
                train_metrics_dict["rl_episode_lengths"] = get_1d_array_statistics(
                    train_metrics_dict["rl_episode_lengths"], histogram=True
                )
                train_metrics_dict["rl_episode_returns"] = get_1d_array_statistics(
                    train_metrics_dict["rl_episode_returns"], histogram=True
                )
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )
            else:
                # Single RL agent: drop the leading agent axis.
                train_metrics_dict["rl_episode_lengths"] = train_metrics_dict[
                    "rl_episode_lengths"
                ].squeeze(0)
                train_metrics_dict["rl_episode_returns"] = train_metrics_dict[
                    "rl_episode_returns"
                ].squeeze(0)

            self.recorder.write(train_metrics_dict, iters)

            is_final_iteration = iters == final_iteration

            if iters % self.config.eval_interval == 0 or is_final_iteration:
                eval_metrics, state = self.evaluate(state)

                eval_metrics_dict = eval_metrics.to_local_dict()
                if self.config.num_rl_agents > 1:
                    eval_metrics_dict = jtu.tree_map(get_1d_array, eval_metrics_dict)

                self.recorder.write(add_prefix(eval_metrics_dict, "eval"), iters)

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            # The final iteration is passed with ``force=True``.
            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=is_final_iteration,
            )

        return state