1import logging
2import time
3from omegaconf import DictConfig
4
5import chex
6import jax
7import jax.numpy as jnp
8import jax.tree_util as jtu
9
10from evorl.metrics import MetricBase
11from evorl.types import PyTreeDict, State
12from evorl.utils.jax_utils import is_jitted
13from evorl.algorithms.td3 import TD3TrainMetric
14
15from ..erl_workflow import ERLTrainMetric as ERLTrainMetricBase
16from .erl_td3_workflow import create_dummy_td3_trainmetric, erl_replace_td3_actor_params
17from .erl_ga import ERLGAWorkflow
18
19
20logger = logging.getLogger(__name__)
21
22
[docs]
class ERLTrainMetric(ERLTrainMetricBase):
    """Training metrics reported by the original-ERL workflow.

    Extends the base ERL training metrics with two extra fields that
    ``ERLWorkflow.step`` fills in on every iteration.
    """

    # Length of the RL update scan run in the iteration (uint32 scalar).
    num_updates_per_iter: chex.Array = jnp.zeros(shape=(), dtype=jnp.uint32)
    # Host-side wall-clock seconds measured across one ``step`` call.
    time_cost_per_iter: float = 0.0
26
27
[docs]
class ERLWorkflow(ERLGAWorkflow):
    """Original ERL implementation.

    The number of RL updates per iteration is dynamic: it is derived from the
    number of sampled timesteps (see ``rl_updates_mode``). Because the scan
    length therefore changes between iterations, `step()` cannot be jitted
    as a whole.
    """

    @classmethod
    def name(cls):
        """Return the display name of this workflow."""
        return "ERL-Origin"

    @classmethod
    def _build_from_config(cls, config: DictConfig):
        """Build the workflow and attach the scan body used for RL updates.

        Args:
            config: Workflow configuration; ``actor_update_interval`` and
                ``num_rl_agents`` are read by the attached scan body.

        Returns:
            The workflow built by the parent class, with
            ``_rl_sample_and_update_fn`` set on the instance.
        """
        workflow = super()._build_from_config(config)

        def _rl_sample_and_update_fn(carry, unused_t):
            # The metrics slot of the incoming carry is discarded; only the
            # metrics of the most recent update are kept.
            rng, agent_state, opt_state, replay_buffer_state, _ = carry

            rng, sample_rng, update_rng = jax.random.split(rng, 3)

            n_intervals = config.actor_update_interval
            n_agents = config.num_rl_agents

            # One replay-buffer draw per (interval, agent) pair.
            sample_rngs = jax.random.split(sample_rng, n_intervals * n_agents)
            flat_batches = jax.vmap(
                lambda k: workflow.replay_buffer.sample(replay_buffer_state, k)
            )(sample_rngs)

            def _split_leading_axis(x):
                return x.reshape((n_intervals, n_agents, *x.shape[1:]))

            # (actor_update_interval, num_learning_offspring, B, ...)
            sample_batches = jtu.tree_map(_split_leading_axis, flat_batches)

            updated_states, losses = workflow._rl_update_fn(
                agent_state, opt_state, sample_batches, update_rng
            )
            agent_state, opt_state = updated_states
            critic_loss, actor_loss, critic_loss_dict, actor_loss_dict = losses

            merged_loss_dict = PyTreeDict({**critic_loss_dict, **actor_loss_dict})
            td3_metrics = TD3TrainMetric(
                actor_loss=actor_loss,
                critic_loss=critic_loss,
                raw_loss_dict=merged_loss_dict,
            )

            # Note: the metrics ride in the carry instead of the per-step
            # output (y_t) to save memory.
            next_carry = (
                rng,
                agent_state,
                opt_state,
                replay_buffer_state,
                td3_metrics,
            )
            return next_carry, None

        # Only jit the scan body when jit has been enabled on the class.
        scan_body = (
            jax.jit(_rl_sample_and_update_fn)
            if is_jitted(cls.evaluate)
            else _rl_sample_and_update_fn
        )
        workflow._rl_sample_and_update_fn = scan_body

        return workflow

    def _ec_update(self, ec_opt_state, fitnesses):
        """Report the fitnesses to the EC optimizer through ``tell``."""
        return self.ec_optimizer.tell(ec_opt_state, fitnesses)

    def _ec_update_with_rl_injection(self, ec_opt_state, agent_state, fitnesses):
        """Apply RL injection, then report fitnesses through ``tell_external``."""
        injected_opt_state = self._rl_injection(ec_opt_state, agent_state)
        return self.ec_optimizer.tell_external(injected_opt_state, fitnesses)

    def _rl_update(self, agent_state, opt_state, replay_buffer_state, key, num_updates):
        """Run ``num_updates`` scan steps of replay sampling and RL updating.

        Returns:
            A tuple ``(td3_metrics, agent_state, opt_state)`` taken from the
            final scan carry.
        """
        # Unlike erl-ga, num_updates is large, so only the last train_info is
        # kept; the dummy metrics seed the metrics slot of the carry.
        initial_carry = (
            key,
            agent_state,
            opt_state,
            replay_buffer_state,
            create_dummy_td3_trainmetric(self.config.num_rl_agents),
        )

        final_carry, _ = jax.lax.scan(
            self._rl_sample_and_update_fn,
            initial_carry,
            (),
            length=num_updates,
        )
        _, agent_state, opt_state, _, td3_metrics = final_carry

        return td3_metrics, agent_state, opt_state

    def step(self, state: State) -> tuple[MetricBase, State]:
        """Run one workflow iteration: EC rollout, RL rollout/update, EC update.

        Args:
            state: Current workflow state.

        Returns:
            A tuple ``(train_metrics, state)`` holding this iteration's
            training metrics and the updated workflow state.

        Raises:
            ValueError: If ``config.rl_updates_mode`` is neither ``"global"``
                nor ``"iter"``.
        """
        tic = time.perf_counter()

        cfg = self.config
        prev_metrics = state.metrics
        agent_state = state.agent_state
        opt_state = state.opt_state
        ec_opt_state = state.ec_opt_state
        replay_buffer_state = state.replay_buffer_state
        iterations = prev_metrics.iterations + 1

        key, ec_rollout_key, rl_rollout_key, learn_key = jax.random.split(
            state.key, num=4
        )

        # ======== EC rollout ========
        # the trajectory [#pop, T, B, ...]
        # metrics: [#pop, B]
        pop_actor_params, ec_opt_state = self.ec_optimizer.ask(ec_opt_state)
        pop_agent_state = erl_replace_td3_actor_params(agent_state, pop_actor_params)
        ec_eval_metrics, _, replay_buffer_state = self._ec_rollout(
            pop_agent_state, replay_buffer_state, ec_rollout_key
        )

        # Timesteps and episodes contributed by the EC population.
        ec_sampled_timesteps = ec_eval_metrics.episode_lengths.sum().astype(
            jnp.uint32
        )
        ec_sampled_episodes = jnp.uint32(cfg.episodes_for_fitness * cfg.pop_size)

        train_metrics = ERLTrainMetric(
            pop_episode_lengths=ec_eval_metrics.episode_lengths.mean(-1),
            pop_episode_returns=ec_eval_metrics.episode_returns.mean(-1),
        )

        # ======== RL update ========
        rl_eval_metrics, _, replay_buffer_state = self._rl_rollout(
            agent_state, replay_buffer_state, rl_rollout_key
        )

        # Pick the timestep count the update budget is based on. At this
        # point only the EC timesteps of this iteration have been counted.
        if cfg.rl_updates_mode == "global":  # same as original ERL
            update_basis = prev_metrics.sampled_timesteps + ec_sampled_timesteps
        elif cfg.rl_updates_mode == "iter":
            update_basis = ec_sampled_timesteps
        else:
            raise ValueError(f"Unknown rl_updates_mode: {self.config.rl_updates_mode}")

        scaled_updates = jnp.ceil(update_basis * cfg.rl_updates_frac).astype(
            jnp.uint32
        )
        num_updates = scaled_updates // cfg.actor_update_interval

        td3_metrics, agent_state, opt_state = self._rl_update(
            agent_state, opt_state, replay_buffer_state, learn_key, num_updates
        )

        # Average the losses over the RL agents.
        td3_metrics = td3_metrics.replace(
            actor_loss=td3_metrics.actor_loss / cfg.num_rl_agents,
            critic_loss=td3_metrics.critic_loss / cfg.num_rl_agents,
        )

        rl_sampled_timesteps = rl_eval_metrics.episode_lengths.sum().astype(jnp.uint32)
        sampled_timesteps = ec_sampled_timesteps + rl_sampled_timesteps
        sampled_episodes = ec_sampled_episodes + jnp.uint32(
            cfg.num_rl_agents * cfg.rollout_episodes
        )

        # ======== EC update ========
        fitnesses = ec_eval_metrics.episode_returns.mean(axis=-1)

        if iterations % cfg.rl_injection_interval == 0:
            ec_metrics, ec_opt_state = self._ec_update_with_rl_injection(
                ec_opt_state, agent_state, fitnesses
            )
        else:
            ec_metrics, ec_opt_state = self._ec_update(ec_opt_state, fitnesses)

        # NOTE(review): the elapsed time is taken on the host; if JAX
        # dispatches asynchronously it may not include pending device work
        # -- confirm.
        train_metrics = train_metrics.replace(
            num_updates_per_iter=num_updates,
            rl_episode_lengths=rl_eval_metrics.episode_lengths.mean(-1),
            rl_episode_returns=rl_eval_metrics.episode_returns.mean(-1),
            rl_metrics=td3_metrics,
            ec_info=ec_metrics,
            rb_size=replay_buffer_state.buffer_size,
            time_cost_per_iter=time.perf_counter() - tic,
        )

        # iterations is the number of updates of the agent
        workflow_metrics = prev_metrics.replace(
            sampled_timesteps=prev_metrics.sampled_timesteps + sampled_timesteps,
            sampled_episodes=prev_metrics.sampled_episodes + sampled_episodes,
            rl_sampled_timesteps=prev_metrics.rl_sampled_timesteps
            + rl_sampled_timesteps,
            iterations=iterations,
        )

        new_state = state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
            ec_opt_state=ec_opt_state,
            opt_state=opt_state,
        )

        return train_metrics, new_state

    @classmethod
    def enable_jit(cls) -> None:
        """Jit the individual sub-steps of the workflow, with ``self`` static.

        ``step`` and ``_rl_update`` are not wrapped here.
        """
        # Do not jit replay buffer add
        jitted_method_names = (
            "_rl_rollout",
            "_ec_rollout",
            "_ec_update",
            "_ec_update_with_rl_injection",
            "evaluate",
            "_postsetup_replaybuffer",
        )
        for method_name in jitted_method_names:
            setattr(
                cls,
                method_name,
                jax.jit(getattr(cls, method_name), static_argnums=(0,)),
            )
248
249
[docs]
def get_ec_pop_statistics(pop):
    """Return the minimum and maximum of every leaf under ``pop["params"]``.

    Args:
        pop: Mapping with a ``"params"`` entry holding a pytree of arrays.

    Returns:
        A pytree with the same structure as ``pop["params"]`` in which each
        leaf is replaced by ``{"min": ..., "max": ...}`` holding Python
        scalars.
    """

    def _leaf_extrema(leaf):
        lowest = jnp.min(leaf).tolist()
        highest = jnp.max(leaf).tolist()
        return {"min": lowest, "max": highest}

    return jtu.tree_map(_leaf_extrema, pop["params"])