1import math
2from omegaconf import DictConfig
3
4import chex
5import jax
6import jax.numpy as jnp
7import jax.tree_util as jtu
8import optax
9
10from evorl.replay_buffers import ReplayBuffer
11from evorl.metrics import MetricBase
12from evorl.types import State, PyTreeDict
13from evorl.evaluators import Evaluator, EpisodeCollector
14from evorl.agent import AgentState
15from evorl.types import Params
16from evorl.envs import create_env, AutoresetMode
17from evorl.recorders import get_1d_array_statistics, add_prefix
18from evorl.ec.optimizers import OpenES, ExponentialScheduleSpec, ECState
19from evorl.utils.jax_utils import tree_set, tree_get
20from evorl.algorithms.td3 import make_mlp_td3_agent, TD3NetworkParams
21from evorl.algorithms.offpolicy_utils import skip_replay_buffer_state
22
23from .cemrl_td3_workflow import (
24 create_dummy_td3_trainmetric,
25 cemrl_replace_td3_actor_params,
26 CEMRLTD3WorkflowTemplate,
27)
28from ..cemrl_workflow import CEMRLTrainMetric
29
30
[docs]
class CEMRLOpenESWorkflow(CEMRLTD3WorkflowTemplate):
    """CEM-RL workflow that uses OpenES as the evolutionary optimizer.

    Structure: 1 critic + n actors + 1 replay buffer. The actor population is
    sampled from an OpenES search distribution at every step. After warmup, a
    random subset of ``num_learning_offspring`` actors is trained with TD3
    together with the shared critic. The RL-updated actors are then injected
    back into the ES update before ``tell``.

    NOTE(review): the original docstring states that shard_map is used to split
    and parallelize the population; no shard_map call is visible in this class,
    so presumably it lives in the template/base class - verify.
    """

    @classmethod
    def name(cls):
        """Return the name of this workflow."""
        return "CEMRL-OpenES"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Construct the workflow and all of its components from ``config``.

        Builds:
            - the training env for one actor, with ``config.num_envs``
              parallel copies;
            - the TD3 agent;
            - the gradient optimizer: Adam, optionally preceded by global-norm
              gradient clipping;
            - the OpenES optimizer;
            - the episode collector used for fitness rollouts;
            - the replay buffer;
            - the evaluator used for the pop-mean actor.

        Raises:
            AssertionError: if neither ``warmup_iters`` nor
                ``random_timesteps`` is positive. In that case no data would
                be pre-filled in the replay buffer.
        """
        assert config.warmup_iters > 0 or config.random_timesteps > 0, (
            "Either warmup_iters or random_timesteps should be positive to pre-fill some data in the replay buffer"
        )

        # env for one actor
        env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_envs,
            autoreset_mode=AutoresetMode.DISABLED,
            record_ori_obs=True,
        )

        agent = make_mlp_td3_agent(
            action_space=env.action_space,
            norm_layer_type=config.agent_network.norm_layer_type,
            num_critics=config.agent_network.num_critics,
            critic_hidden_layer_sizes=config.agent_network.critic_hidden_layer_sizes,
            actor_hidden_layer_sizes=config.agent_network.actor_hidden_layer_sizes,
            discount=config.discount,
            exploration_epsilon=config.exploration_epsilon,
            policy_noise=config.policy_noise,
            clip_policy_noise=config.clip_policy_noise,
            critics_in_actor_loss=config.critics_in_actor_loss,
            normalize_obs=config.normalize_obs,
        )

        # Gradient clipping is enabled only for a positive, non-None norm.
        if (
            config.optimizer.grad_clip_norm is not None
            and config.optimizer.grad_clip_norm > 0
        ):
            optimizer = optax.chain(
                optax.clip_by_global_norm(config.optimizer.grad_clip_norm),
                optax.adam(config.optimizer.lr),
            )
        else:
            optimizer = optax.adam(config.optimizer.lr)

        ec_optimizer = OpenES(
            pop_size=config.pop_size,
            lr_schedule=ExponentialScheduleSpec(**config.ec_lr),
            noise_std_schedule=ExponentialScheduleSpec(**config.ec_noise_std),
            mirror_sampling=config.mirror_sampling,
        )

        # Fitness rollouts use either the exploratory or the deterministic policy.
        if config.fitness_with_exploration:
            action_fn = agent.compute_actions
        else:
            action_fn = agent.evaluate_actions

        collector = EpisodeCollector(
            env=env,
            action_fn=action_fn,
            max_episode_steps=config.env.max_episode_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        replay_buffer = ReplayBuffer(
            capacity=config.replay_buffer_capacity,
            min_sample_timesteps=config.batch_size,
            sample_batch_size=config.batch_size,
        )

        # to evaluate the pop-mean actor
        eval_env = create_env(
            config.env,
            episode_length=config.env.max_episode_steps,
            parallel=config.num_eval_envs,
            autoreset_mode=AutoresetMode.DISABLED,
        )

        evaluator = Evaluator(
            env=eval_env,
            action_fn=agent.evaluate_actions,
            max_episode_steps=config.env.max_episode_steps,
        )

        # Only actor params (and their targets) carry a population axis;
        # the critic and the obs preprocessor are shared across the population.
        agent_state_vmap_axes = AgentState(
            params=TD3NetworkParams(
                critic_params=None,
                actor_params=0,
                target_critic_params=None,
                target_actor_params=0,
            ),
            obs_preprocessor_state=None,
        )

        workflow = cls(
            env=env,
            agent=agent,
            agent_state_vmap_axes=agent_state_vmap_axes,
            optimizer=optimizer,
            ec_optimizer=ec_optimizer,
            collector=collector,
            evaluator=evaluator,
            replay_buffer=replay_buffer,
            config=config,
        )

        return workflow

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Initialize the agent, the critic optimizer state and the OpenES state.

        The actor params of a freshly initialized agent are passed to
        ``ec_optimizer.init``. The actor is then removed from ``agent_state``,
        because the population is obtained from the ES state via ``ask`` at
        every step.

        Returns:
            ``(agent_state, opt_state, ec_opt_state)``. ``opt_state.actor`` is
            ``None``.
        """
        agent_key, ec_key = jax.random.split(key)

        # one actor + one critic
        agent_state = self.agent.init(
            self.env.obs_space, self.env.action_space, agent_key
        )

        init_actor_params = agent_state.params.actor_params
        ec_opt_state = self.ec_optimizer.init(init_actor_params, ec_key)

        agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params=None)

        opt_state = PyTreeDict(
            # Note: we create and drop the actors' opt_state at every step
            critic=self.optimizer.init(agent_state.params.critic_params),
            actor=None,
        )

        return agent_state, opt_state, ec_opt_state

    def _rl_injection(
        self, ec_opt_state: ECState, pop: Params, external_indices
    ) -> ECState:
        """Overwrite the stored ES noise of externally (RL-)updated individuals.

        For each index ``i`` in ``external_indices``, the stored noise is
        replaced by ``(pop[i] - mean) / noise_std``. The subsequent ``tell``
        therefore attributes the fitness to the RL-updated parameters.
        Presumably OpenES reconstructs samples as ``mean + noise_std * noise``;
        verify against the optimizer implementation.
        """
        external_noise = jtu.tree_map(
            lambda x, m: (x - m) / ec_opt_state.noise_std,
            tree_get(pop, external_indices),
            ec_opt_state.mean,
        )
        noise = tree_set(
            ec_opt_state.noise,
            external_noise,
            external_indices,
            unique_indices=True,
        )

        return ec_opt_state.replace(noise=noise)

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one CEM-RL iteration.

        1. Sample a population of actors from the OpenES distribution
           (``ask``).
        2. After ``warmup_iters`` iterations, train the shared critic with
           TD3 on the replay buffer. In the same update, train
           ``num_learning_offspring`` distinct, randomly chosen actors.
        3. Roll out every actor to compute its fitness (mean episode return).
           ``_rollout`` also returns the updated replay buffer state.
        4. After warmup, inject the RL-updated actors into the ES noise. Then
           update the ES distribution with the fitnesses (``tell``).

        Returns:
            ``(train_metrics, new_state)``.
        """
        pop_size = self.config.pop_size
        num_learning_offspring = self.config.num_learning_offspring
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1
        # RL training and RL injection are both gated on the end of warmup.
        is_after_warmup = iterations > self.config.warmup_iters

        key, rollout_key, perm_key, learn_key = jax.random.split(state.key, num=4)

        # ======= CEM Sample ========
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)

        # ======== RL update ========
        learning_actor_indices = jax.random.choice(
            perm_key,
            pop_size,
            (num_learning_offspring,),
            replace=False,
        )

        def _td3_update(agent_state, opt_state, pop_actor_params):
            learning_actor_params = tree_get(pop_actor_params, learning_actor_indices)
            learning_agent_state = cemrl_replace_td3_actor_params(
                agent_state, learning_actor_params
            )

            # reset actors' opt_state: a fresh one is created at every step
            learning_opt_state = opt_state.replace(
                actor=self.optimizer.init(learning_actor_params),
            )

            td3_metrics, learning_agent_state, learning_opt_state = self._rl_update(
                learning_agent_state,
                learning_opt_state,
                replay_buffer_state,
                learn_key,
            )

            # Write the trained actors back into their population slots.
            pop_actor_params = tree_set(
                pop_actor_params,
                learning_agent_state.params.actor_params,
                learning_actor_indices,
                unique_indices=True,
            )
            # drop the actors and their opt_state
            agent_state = cemrl_replace_td3_actor_params(
                learning_agent_state, pop_actor_params=None
            )
            opt_state = learning_opt_state.replace(actor=None)
            return td3_metrics, pop_actor_params, agent_state, opt_state

        def _skip_td3_update(agent_state, opt_state, pop_actor_params):
            # jax.lax.cond needs both branches to return the same structure,
            # so emit placeholder TD3 metrics during warmup.
            return (
                create_dummy_td3_trainmetric(num_learning_offspring),
                pop_actor_params,
                agent_state,
                opt_state,
            )

        td3_metrics, pop_actor_params, agent_state, opt_state = jax.lax.cond(
            is_after_warmup,
            _td3_update,
            _skip_td3_update,
            agent_state,
            opt_state,
            pop_actor_params,
        )

        # ======== CEM update ========
        pop_agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params)

        # the trajectory [T, #pop*B, ...]
        # metrics: [#pop, B]
        eval_metrics, trajectory, replay_buffer_state = self._rollout(
            pop_agent_state, replay_buffer_state, rollout_key
        )

        fitnesses = eval_metrics.episode_returns.mean(axis=-1)

        ec_opt_state = jax.lax.cond(
            is_after_warmup,
            self._rl_injection,
            lambda ec_opt_state, pop, external_indices: ec_opt_state,
            ec_opt_state,
            pop_actor_params,
            learning_actor_indices,
        )
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        train_metrics = CEMRLTrainMetric(
            rb_size=replay_buffer_state.buffer_size,
            pop_episode_lengths=eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=fitnesses,
            rl_metrics=td3_metrics,
            ec_info=ec_metrics,
        )

        # calculate the number of timestep
        sampled_timesteps = eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    def learn(self, state: State) -> State:
        """Train until ``config.total_episodes`` episodes have been sampled.

        Each iteration samples ``episodes_for_fitness * pop_size`` episodes.
        Workflow and train metrics are recorded every iteration. Evaluation
        and checkpointing happen every ``eval_interval`` iterations and at the
        final iteration, where the checkpoint save is forced. If the episode
        budget is already used up, no iteration is run.
        """
        episodes_per_iter = self.config.episodes_for_fitness * self.config.pop_size
        start_iteration = int(state.metrics.iterations)
        # Convert the counter to a Python int before subtracting. ``step``
        # accumulates it as jnp.uint32, so the subtraction would otherwise
        # wrap around to a huge value once the budget has been exceeded. That
        # happens e.g. when resuming from the final checkpoint while
        # total_episodes is not a multiple of episodes_per_iter.
        remaining_episodes = max(
            self.config.total_episodes - int(state.metrics.sampled_episodes), 0
        )
        num_iters = math.ceil(remaining_episodes / episodes_per_iter)
        final_iteration = start_iteration + num_iters

        for i in range(start_iteration, final_iteration):
            iters = i + 1
            train_metrics, state = self.step(state)
            workflow_metrics = state.metrics

            workflow_metrics_dict = workflow_metrics.to_local_dict()
            self.recorder.write(workflow_metrics_dict, iters)

            train_metrics_dict = train_metrics.to_local_dict()
            train_metrics_dict["pop_episode_returns"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_returns"], histogram=True
            )

            train_metrics_dict["pop_episode_lengths"] = get_1d_array_statistics(
                train_metrics_dict["pop_episode_lengths"], histogram=True
            )

            if train_metrics_dict["rl_metrics"] is not None:
                train_metrics_dict["rl_metrics"]["raw_loss_dict"] = jtu.tree_map(
                    get_1d_array_statistics,
                    train_metrics_dict["rl_metrics"]["raw_loss_dict"],
                )

            self.recorder.write(train_metrics_dict, iters)

            if iters % self.config.eval_interval == 0 or iters == final_iteration:
                eval_metrics, state = self.evaluate(state)

                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iters
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)

            self.checkpoint_manager.save(
                iters,
                saved_state,
                force=iters == final_iteration,
            )

        return state