1import copy
2import logging
3from typing_extensions import Self # pytype: disable=not-supported-yet]
4from omegaconf import DictConfig
5
6import chex
7import jax
8import jax.numpy as jnp
9import optax
10
11from evorl.agent import AgentStateAxis
12from evorl.metrics import MetricBase, metric_field
13from evorl.types import PyTreeDict, State
14from evorl.utils import running_statistics
15from evorl.utils.jax_utils import tree_stop_gradient
16from evorl.utils.rl_toolkits import flatten_rollout_trajectory
17from evorl.evaluators import Evaluator, EpisodeCollector
18from evorl.sample_batch import SampleBatch
19from evorl.replay_buffers import AbstractReplayBuffer, ReplayBufferState
20from evorl.agent import Agent, AgentState, RandomAgent
21from evorl.envs import create_env, AutoresetMode, Env
22from evorl.workflows import Workflow
23from evorl.rollout import rollout
24from evorl.ec.optimizers import EvoOptimizer, ECState
25
26from ..offpolicy_utils import clean_trajectory
27
28logger = logging.getLogger(__name__)
29
30
[docs]
class ERLTrainMetric(MetricBase):
    """Training metrics record for ERL workflows.

    Every field is optional: the array fields and ``rl_metrics`` default to
    ``None``, and ``ec_info`` is declared through
    ``metric_field(default_factory=PyTreeDict)``.
    """

    # Episode statistics with the ``pop_`` prefix
    # (presumably gathered from the population rollouts -- TODO confirm).
    pop_episode_returns: chex.Array | None = None
    pop_episode_lengths: chex.Array | None = None

    # Replay buffer size
    # (presumably the number of stored transitions -- TODO confirm).
    rb_size: chex.Array | None = None

    # Episode statistics with the ``rl_`` prefix
    # (presumably gathered from the RL agent rollouts -- TODO confirm).
    rl_episode_returns: chex.Array | None = None
    rl_episode_lengths: chex.Array | None = None

    # Nested metrics object for the RL update, when one is available.
    rl_metrics: MetricBase | None = None

    # Free-form information from the evolutionary side, stored as a PyTreeDict.
    ec_info: PyTreeDict = metric_field(default_factory=PyTreeDict)
39
40
[docs]
class WorkflowMetric(MetricBase):
    """Workflow-level counters.

    Each counter defaults to a scalar (shape ``()``) ``uint32`` zero.
    """

    # Counters for sampled timesteps / episodes.
    sampled_timesteps: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    sampled_episodes: chex.Array = jnp.zeros((), dtype=jnp.uint32)

    # Timestep counter with the ``rl_`` prefix
    # (presumably timesteps sampled by the RL agent only -- TODO confirm).
    rl_sampled_timesteps: chex.Array = jnp.zeros((), dtype=jnp.uint32)

    # Iteration counter.
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
46
47
[docs]
class ERLWorkflowBase(Workflow):
    """Base class for ERL workflows.

    Holds the environment, agent, optimizers, collectors, evaluator and
    replay buffer, and implements the state setup they share: splitting the
    PRNG key, initialising the replay buffer from a dummy sample, and
    optionally pre-filling it with transitions from a random agent.

    ``_build_from_config``, ``_setup_agent_and_optimizer``, ``warmup_step``,
    ``_rl_injection`` and ``evaluate`` raise ``NotImplementedError`` here and
    are expected to be provided by subclasses.
    """

    def __init__(
        self,
        *,
        env: Env,
        agent: Agent,
        agent_state_vmap_axes: AgentStateAxis,
        optimizer: optax.GradientTransformation,
        ec_optimizer: EvoOptimizer,
        ec_collector: EpisodeCollector,
        rl_collector: EpisodeCollector,
        evaluator: Evaluator,
        replay_buffer: AbstractReplayBuffer,
        config: DictConfig,
    ):
        """Store the workflow components on the instance.

        All arguments are keyword-only. ``config`` is forwarded to the parent
        ``Workflow`` constructor; every other argument is kept as an
        attribute of the same name. ``evaluator`` is the one used to evaluate
        the pop-mean actor.
        """
        super().__init__(config)

        # Environment and agent.
        self.env = env
        self.agent = agent
        self.agent_state_vmap_axes = agent_state_vmap_axes

        # Gradient-based and evolutionary optimizers.
        self.optimizer = optimizer
        self.ec_optimizer = ec_optimizer

        # Rollout collectors, evaluator and replay buffer.
        self.ec_collector = ec_collector
        self.rl_collector = rl_collector
        self.evaluator = evaluator
        self.replay_buffer = replay_buffer

        # Keep only the first local device.
        self.devices = jax.local_devices()[:1]

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Create a workflow instance from ``config``.

        The config is deep-copied first so the caller's object is not
        modified in place.

        Args:
            config: Workflow configuration.
            enable_multi_devices: Multi-device execution; requesting it is
                rejected.
            enable_jit: When true, ``cls.enable_jit()`` is called before the
                workflow is built.

        Returns:
            The object returned by ``cls._build_from_config``.

        Raises:
            NotImplementedError: If ``enable_multi_devices`` is true or more
                than one local JAX device is visible.
        """
        private_config = copy.deepcopy(config)

        # NOTE(review): this also rejects hosts that merely *have* several
        # devices, even when multi-device execution was not requested.
        has_several_devices = len(jax.local_devices()) > 1
        if enable_multi_devices or has_several_devices:
            raise NotImplementedError("Multi-devices is not supported yet.")

        if enable_jit:
            cls.enable_jit()

        return cls._build_from_config(private_config)

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the workflow from ``config``; not implemented here."""
        raise NotImplementedError

    def setup(self, key: chex.PRNGKey) -> State:
        """Build the initial workflow state.

        ``key`` is split three ways: one part is kept in the returned state,
        one seeds the agent/optimizer setup and one seeds the replay-buffer
        setup. When ``config.random_timesteps`` is positive, the replay
        buffer is additionally filled by ``_postsetup_replaybuffer``.
        """
        state_key, agent_key, rb_key = jax.random.split(key, 3)

        agent_state, opt_state, ec_opt_state = self._setup_agent_and_optimizer(
            agent_key
        )
        metrics = self._setup_workflow_metrics()
        rb_state = self._setup_replaybuffer(rb_key)

        initial_state = State(
            key=state_key,
            metrics=metrics,
            agent_state=agent_state,
            opt_state=opt_state,
            ec_opt_state=ec_opt_state,
            replay_buffer_state=rb_state,
        )

        # Nothing more to do when no random pre-fill is requested.
        if not self.config.random_timesteps > 0:
            return initial_state

        logger.info("Start replay buffer post-setup")
        filled_state = self._postsetup_replaybuffer(initial_state)
        logger.info("Complete replay buffer post-setup")

        return filled_state

    def _setup_workflow_metrics(self) -> MetricBase:
        """Return a ``WorkflowMetric`` with its default values."""
        return WorkflowMetric()

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Create the agent, optimizer and EC optimizer states; not implemented here."""
        raise NotImplementedError

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Initialise the replay buffer from a dummy sample.

        The dummy sample consists of a zero action shaped like the env's
        action space, an observation drawn from the obs space with ``key``,
        a scalar zero reward, and ``env_extras`` holding ``ori_obs`` and
        ``termination``.
        """
        env = self.env

        # Placeholder values that only describe the layout of one sample.
        obs_template = env.obs_space.sample(key)
        action_template = jnp.zeros(env.action_space.shape)
        reward_template = jnp.zeros(())
        done_template = jnp.zeros(())

        env_extras = PyTreeDict(
            {"ori_obs": obs_template, "termination": done_template}
        )
        extras = PyTreeDict(policy_extras=PyTreeDict(), env_extras=env_extras)

        # next_obs and dones are not populated in the template.
        sample_template = SampleBatch(
            obs=obs_template,
            actions=action_template,
            rewards=reward_template,
            extras=extras,
        )

        return self.replay_buffer.init(sample_template)

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Fill the replay buffer with transitions from a random agent.

        A separate env is created (``AutoresetMode.NORMAL``,
        ``record_ori_obs=True``) and rolled out for
        ``config.random_timesteps // config.num_envs`` steps. The resulting
        trajectory is cleaned, flattened, gradient-stopped and added to the
        replay buffer. If the agent state carries an
        ``obs_preprocessor_state`` and the rollout length is positive, it is
        updated with the collected observations.

        Returns:
            ``state`` with ``key``, ``agent_state`` and
            ``replay_buffer_state`` replaced. The metrics are left untouched.
        """
        cfg = self.config
        obs_space = self.env.obs_space
        action_space = self.env.action_space

        # We need a separate autoreset env to fill the replay buffer.
        fill_env = create_env(
            cfg.env,
            episode_length=cfg.env.max_episode_steps,
            parallel=cfg.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        # ==== fill random transitions ====
        next_key, env_key, rollout_key = jax.random.split(state.key, num=3)

        explorer = RandomAgent()
        explorer_state = explorer.init(
            obs_space, action_space, jax.random.PRNGKey(0)
        )
        num_steps = cfg.random_timesteps // cfg.num_envs

        raw_trajectory, _ = rollout(
            env_fn=fill_env.step,
            action_fn=explorer.compute_actions,
            env_state=fill_env.reset(env_key),
            agent_state=explorer_state,
            key=rollout_key,
            rollout_length=num_steps,
            env_extra_fields=("ori_obs", "termination"),
        )

        # clean -> flatten ([T, B, ...] -> [T*B, ...]) -> stop gradients
        transitions = tree_stop_gradient(
            flatten_rollout_trajectory(clean_trajectory(raw_trajectory))
        )

        rb_state = self.replay_buffer.add(state.replay_buffer_state, transitions)

        agent_state = state.agent_state
        preprocessor_state = agent_state.obs_preprocessor_state
        if preprocessor_state is not None and num_steps > 0:
            agent_state = agent_state.replace(
                obs_preprocessor_state=running_statistics.update(
                    preprocessor_state, transitions.obs
                )
            )

        # Note: the random transitions are not counted in the metrics, to
        # align the x-axis (sampled_episodes/iterations) over different runs.
        # Since they come from an autoreset env, an episode count derived
        # from them might not be accurate anyway.
        return state.replace(
            key=next_key,
            agent_state=agent_state,
            replay_buffer_state=rb_state,
        )

    def warmup_step(self, state: State) -> tuple[MetricBase, State]:
        """Run one warm-up step; not implemented here."""
        raise NotImplementedError

    def _rl_injection(self, *args, **kwargs):
        """RL injection hook; not implemented here."""
        raise NotImplementedError

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the workflow at ``state``; not implemented here."""
        raise NotImplementedError

    @classmethod
    def enable_jit(cls) -> None:
        """Replace selected methods on ``cls`` with ``jax.jit``-wrapped versions.

        ``step``, ``evaluate``, ``_postsetup_replaybuffer`` and
        ``warmup_step`` are wrapped, in that order, with argument 0 (the
        instance) marked static.
        """
        jitted_method_names = (
            "step",
            "evaluate",
            "_postsetup_replaybuffer",
            "warmup_step",
        )
        for method_name in jitted_method_names:
            plain_method = getattr(cls, method_name)
            setattr(cls, method_name, jax.jit(plain_method, static_argnums=(0,)))