1import time
2from omegaconf import DictConfig
3
4import chex
5import jax
6import jax.numpy as jnp
7import jax.tree_util as jtu
8
9from evorl.metrics import MetricBase
10from evorl.types import PyTreeDict, State
11from evorl.utils.jax_utils import (
12 tree_get,
13 tree_set,
14 scan_and_last,
15 is_jitted,
16)
17from evorl.algorithms.td3 import TD3TrainMetric
18
19from ..cemrl_workflow import CEMRLTrainMetric as CEMRLTrainMetricBase
20from .cemrl_td3_workflow import cemrl_replace_td3_actor_params
21from .cemrl import CEMRLWorkflow as _CEMRLWorkflow
22
23
class CEMRLTrainMetric(CEMRLTrainMetricBase):
    """Per-iteration training metrics of the original CEM-RL workflow.

    Adds the RL update count and the wall-clock cost of the iteration on top
    of the shared CEM-RL train metric fields.
    """

    # number of RL sample-and-update steps executed in this iteration
    num_updates_per_iter: chex.Array = jnp.array(0, dtype=jnp.uint32)
    # host wall-clock seconds measured by ``time.perf_counter`` for the step
    time_cost_per_iter: float = 0.0


class WorkflowMetric(MetricBase):
    """Running counters kept in the workflow state between iterations."""

    # total environment timesteps sampled since the start of training
    sampled_timesteps: chex.Array = jnp.array(0, dtype=jnp.uint32)
    # environment timesteps sampled by the most recent iteration
    sampled_timesteps_per_iter: chex.Array = jnp.array(0, dtype=jnp.uint32)
    # total episodes sampled since the start of training
    sampled_episodes: chex.Array = jnp.array(0, dtype=jnp.uint32)
    # number of completed workflow iterations
    iterations: chex.Array = jnp.array(0, dtype=jnp.uint32)


class CEMRLWorkflow(_CEMRLWorkflow):
    """Original CEMRL impl.

    1 critic + n actors + 1 replay buffer.

    Each iteration asks the EC optimizer for a population of actor params.
    After ``warmup_iters`` iterations, ``num_learning_offspring`` distinct
    members are trained with TD3 (starting from ``agent_state`` with their
    actor params swapped in) and written back into the population. The
    population is then rolled out, and its fitnesses are told to the EC
    optimizer.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "CEMRL-Origin"

    def _setup_workflow_metrics(self) -> MetricBase:
        """Return fresh, zero-initialised workflow counters."""
        return WorkflowMetric()

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the workflow and attach one RL sample-and-update step.

        The step is stored as ``workflow._rl_sample_and_update_fn`` and is
        driven repeatedly by ``_rl_update`` through ``scan_and_last``.
        """
        workflow = super()._build_from_config(config)

        def _rl_sample_and_update_fn(carry, unused_t):
            key, agent_state, opt_state, replay_buffer_state = carry

            def _sample_fn(key):
                return workflow.replay_buffer.sample(replay_buffer_state, key)

            key, rb_key, learn_key = jax.random.split(key, 3)
            # one independent replay-buffer batch per
            # (actor_update_interval step, learning actor) pair
            rb_keys = jax.random.split(
                rb_key,
                config.actor_update_interval * config.num_learning_offspring,
            )
            sample_batches = jax.vmap(_sample_fn)(rb_keys)

            # (actor_update_interval, num_learning_offspring, B, ...)
            sample_batches = jtu.tree_map(
                lambda x: x.reshape(
                    (
                        config.actor_update_interval,
                        config.num_learning_offspring,
                        *x.shape[1:],
                    )
                ),
                sample_batches,
            )

            (agent_state, opt_state), train_info = workflow._rl_update_fn(
                agent_state, opt_state, sample_batches, learn_key
            )

            return (key, agent_state, opt_state, replay_buffer_state), train_info

        # Only jit the step when jit is enabled for the class
        # (``enable_jit`` replaces ``evaluate`` with a jitted version).
        if is_jitted(cls.evaluate):
            _rl_sample_and_update_fn = jax.jit(_rl_sample_and_update_fn)

        workflow._rl_sample_and_update_fn = _rl_sample_and_update_fn

        return workflow

    def _ec_sample(self, ec_opt_state):
        """Ask the EC optimizer for a new population of actor params.

        Returns ``(pop_actor_params, ec_opt_state)`` as produced by
        ``self.ec_optimizer.ask``.
        """
        return self.ec_optimizer.ask(ec_opt_state)

    def _rl_update(
        self,
        agent_state,
        opt_state,
        replay_buffer_state,
        key,
        num_updates,
    ):
        """Run ``num_updates`` TD3 sample-and-update steps for the learning actors.

        Add num_updates support. Therefore this method cannot be jitted:
        ``num_updates`` is used as the ``scan_and_last`` length and must be
        a concrete value.

        Returns:
            ``(td3_metrics, agent_state, opt_state)``. The metrics are built
            from the train info returned by ``scan_and_last`` (presumably
            that of the final step).
        """
        (
            (_, agent_state, opt_state, replay_buffer_state),
            train_info,
        ) = scan_and_last(
            self._rl_sample_and_update_fn,
            (key, agent_state, opt_state, replay_buffer_state),
            (),
            length=num_updates,
        )

        critic_loss, actor_loss, critic_loss_dict, actor_loss_dict = train_info

        # smoothed td3 metrics
        td3_metrics = TD3TrainMetric(
            actor_loss=actor_loss,
            critic_loss=critic_loss,
            raw_loss_dict=PyTreeDict({**critic_loss_dict, **actor_loss_dict}),
        )

        # Normalise actor_loss by the number of learning actors.
        # NOTE(review): assumes _rl_update_fn sums it over them — confirm.
        td3_metrics = td3_metrics.replace(
            actor_loss=td3_metrics.actor_loss / self.config.num_learning_offspring
        )

        return td3_metrics, agent_state, opt_state

    def _rollout_and_update(
        self, pop_agent_state, replay_buffer_state, ec_opt_state, key
    ):
        """Calculate the fitness and update the replay buffer and ec_optimizer."""
        # the trajectory [T, #pop*B, ...]
        # metrics: [#pop, B]
        eval_metrics, _trajectory, replay_buffer_state = self._rollout(
            pop_agent_state, replay_buffer_state, key
        )

        # fitness of each member = its mean episode return
        fitnesses = eval_metrics.episode_returns.mean(axis=-1)
        ec_metrics, ec_opt_state = self.ec_optimizer.tell(ec_opt_state, fitnesses)

        return eval_metrics, ec_metrics, fitnesses, replay_buffer_state, ec_opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one CEM-RL iteration and return ``(train_metrics, new_state)``.

        1. Sample a population of actor params from the EC optimizer.
        2. After ``warmup_iters`` iterations:
           - pick ``num_learning_offspring`` distinct members at random;
           - train them with TD3 via ``_rl_update``;
           - write them back into the population;
           - pass the population to ``_rl_injection``.
        3. Roll out the population, update the replay buffer, and tell the
           EC optimizer the fitnesses.

        Not jittable: it branches on ``iterations`` in Python, and
        ``_rl_update`` needs a concrete ``num_updates``.
        """
        start_t = time.perf_counter()
        pop_size = self.config.pop_size
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = state.metrics.iterations + 1

        key, perm_key, rollout_key, learn_key = jax.random.split(state.key, num=4)

        # ======= CEM Sample ========
        pop_actor_params, ec_opt_state = self._ec_sample(ec_opt_state)

        # ======== RL update ========
        if iterations > self.config.warmup_iters:
            # distinct population slots whose actors are trained this iteration
            learning_actor_indices = jax.random.choice(
                perm_key,
                pop_size,
                (self.config.num_learning_offspring,),
                replace=False,
            )
            learning_actor_params = tree_get(pop_actor_params, learning_actor_indices)
            learning_agent_state = cemrl_replace_td3_actor_params(
                agent_state, learning_actor_params
            )
            # reset and add actors' opt_state
            learning_opt_state = opt_state.replace(
                actor=self.optimizer.init(learning_actor_params),
            )

            # The update budget is proportional to the timesteps sampled by
            # the previous iteration. Each sample-and-update step consumes
            # ``actor_update_interval`` batches per learning actor.
            num_updates = (
                jnp.ceil(
                    state.metrics.sampled_timesteps_per_iter
                    * self.config.rl_updates_frac
                ).astype(jnp.uint32)
                // self.config.actor_update_interval
            )

            td3_metrics, learning_agent_state, learning_opt_state = self._rl_update(
                learning_agent_state,
                learning_opt_state,
                replay_buffer_state,
                learn_key,
                num_updates,
            )

            # write the trained actors back into their population slots
            pop_actor_params = tree_set(
                pop_actor_params,
                learning_agent_state.params.actor_params,
                learning_actor_indices,
                unique_indices=True,
            )

            # drop the actors and their opt_state
            agent_state = cemrl_replace_td3_actor_params(
                learning_agent_state, pop_actor_params=None
            )
            opt_state = learning_opt_state.replace(actor=None)

            # rl injection
            ec_opt_state = self._rl_injection(ec_opt_state, pop_actor_params)

        else:
            num_updates = jnp.zeros((), dtype=jnp.uint32)
            td3_metrics = None

        # ======== CEM update ========
        pop_agent_state = cemrl_replace_td3_actor_params(agent_state, pop_actor_params)
        eval_metrics, ec_metrics, fitnesses, replay_buffer_state, ec_opt_state = (
            self._rollout_and_update(
                pop_agent_state,
                replay_buffer_state,
                ec_opt_state,
                rollout_key,
            )
        )

        # adding debug info for CEM
        ec_info = PyTreeDict(ec_metrics)
        ec_info.cov_eps = ec_opt_state.cov_eps
        if td3_metrics is not None:
            # how many RL-trained actors ended up among the top-k (elite) members
            elites_indices = jax.lax.top_k(fitnesses, self.config.num_elites)[1]
            elites_from_rl = jnp.isin(learning_actor_indices, elites_indices)
            ec_info.elites_from_rl = elites_from_rl.sum()
            ec_info.elites_from_rl_ratio = elites_from_rl.mean()

        # JAX dispatches device work asynchronously. Wait for this
        # iteration's rollout / EC / replay-buffer results before reading the
        # clock, so time_cost_per_iter covers this step's work rather than
        # only its dispatch. Otherwise the pending work is paid at a later
        # host sync and charged to the next iteration.
        jax.block_until_ready(
            (fitnesses, ec_opt_state.cov_eps, replay_buffer_state.buffer_size)
        )

        train_metrics = CEMRLTrainMetric(
            rb_size=replay_buffer_state.buffer_size,
            pop_episode_lengths=eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
            ec_info=ec_info,
            num_updates_per_iter=num_updates,
            time_cost_per_iter=time.perf_counter() - start_t,
        )

        # calculate the number of timestep
        sampled_timesteps = eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_episodes = jnp.uint32(self.config.episodes_for_fitness * pop_size)

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
            sampled_timesteps_per_iter=sampled_timesteps,
            sampled_episodes=state.metrics.sampled_episodes + sampled_episodes,
            iterations=iterations,
        )

        state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, state

    @classmethod
    def enable_jit(cls) -> None:
        """Jit-compile the jittable methods of this workflow class.

        ``step`` and ``_rl_update`` are left as Python-level methods.
        """
        cls._ec_sample = jax.jit(cls._ec_sample, static_argnums=(0,))
        cls._rollout_and_update = jax.jit(cls._rollout_and_update, static_argnums=(0,))

        cls.evaluate = jax.jit(cls.evaluate, static_argnums=(0,))
        cls._postsetup_replaybuffer = jax.jit(
            cls._postsetup_replaybuffer, static_argnums=(0,)
        )