1import copy
2import logging
3from typing_extensions import Self # pytype: disable=not-supported-yet]
4from omegaconf import DictConfig
5
6import chex
7import jax
8import jax.numpy as jnp
9import optax
10
11from evorl.agent import AgentStateAxis
12from evorl.metrics import MetricBase, WorkflowMetric, metric_field
13from evorl.types import PyTreeDict, State
14from evorl.utils import running_statistics
15from evorl.utils.jax_utils import tree_stop_gradient
16from evorl.utils.rl_toolkits import flatten_rollout_trajectory
17from evorl.evaluators import Evaluator, EpisodeCollector
18from evorl.sample_batch import SampleBatch
19from evorl.replay_buffers import AbstractReplayBuffer, ReplayBufferState
20from evorl.agent import Agent, AgentState, RandomAgent
21from evorl.envs import create_env, AutoresetMode, Env
22from evorl.workflows import Workflow
23from evorl.rollout import rollout
24from evorl.ec.optimizers import EvoOptimizer, ECState
25
26from ..offpolicy_utils import clean_trajectory
27
28logger = logging.getLogger(__name__)
29
30
[docs]
class CEMRLTrainMetric(MetricBase):
    """Per-step training metrics reported by CEMRL workflows.

    The first three fields are required; ``rl_metrics`` and ``ec_info`` are
    optional and default to ``None`` and an empty ``PyTreeDict`` respectively.
    """

    # presumably the current number of transitions held in the replay buffer
    # -- TODO confirm against the workflow that fills this in
    rb_size: chex.Array

    # presumably episode returns / lengths collected for the population
    # -- TODO confirm shape and aggregation against the caller
    pop_episode_returns: chex.Array
    pop_episode_lengths: chex.Array

    # nested metrics of the RL update; None when not provided
    rl_metrics: MetricBase | None = None

    # extra info from the evolutionary optimizer; a fresh PyTreeDict by default
    ec_info: PyTreeDict = metric_field(default_factory=PyTreeDict)
37
38
[docs]
class CEMRLWorkflowBase(Workflow):
    """Base class for CEMRL workflows, bundling commonly needed helpers.

    It stores the environment, agent, optimizers, collector, evaluator and
    replay buffer, builds the initial workflow ``State`` and can optionally
    pre-fill the replay buffer with random transitions. The algorithm-specific
    hooks (``_build_from_config``, ``_setup_agent_and_optimizer``,
    ``_rl_injection`` and ``evaluate``) raise ``NotImplementedError`` here.
    """

    def __init__(
        self,
        *,
        env: Env,
        agent: Agent,
        agent_state_vmap_axes: AgentStateAxis,
        optimizer: optax.GradientTransformation,
        ec_optimizer: EvoOptimizer,
        collector: EpisodeCollector,
        evaluator: Evaluator,  # to evaluate the pop-mean actor
        replay_buffer: AbstractReplayBuffer,
        config: DictConfig,
    ):
        """Store the workflow components (all arguments are keyword-only).

        Args:
            env: Environment; its action/observation spaces are used when
                setting up the replay buffer.
            agent: The agent.
            agent_state_vmap_axes: Vmap axes specification for the agent state.
            optimizer: Optax gradient transformation.
            ec_optimizer: Evolutionary optimizer.
            collector: Episode collector.
            evaluator: Evaluator, used to evaluate the pop-mean actor.
            replay_buffer: Replay buffer implementation.
            config: Workflow configuration, forwarded to the parent class.
        """
        super().__init__(config)

        self.env = env
        self.agent = agent
        self.agent_state_vmap_axes = agent_state_vmap_axes

        self.optimizer = optimizer
        self.ec_optimizer = ec_optimizer

        self.collector = collector
        self.evaluator = evaluator
        self.replay_buffer = replay_buffer

        # Only the first local device is kept.
        self.devices = jax.local_devices()[:1]

    @classmethod
    def build_from_config(
        cls,
        config: DictConfig,
        enable_multi_devices: bool = False,
        enable_jit: bool = True,
    ) -> Self:
        """Create a workflow instance from ``config``.

        Args:
            config: Workflow configuration; a deep copy is used so the
                caller's object is never modified in place.
            enable_multi_devices: Must be ``False``; multi-device execution is
                not supported.
            enable_jit: When true, ``cls.enable_jit()`` is invoked before the
                workflow is built.

        Returns:
            The workflow produced by ``cls._build_from_config``.

        Raises:
            NotImplementedError: If ``enable_multi_devices`` is set or more
                than one local device is visible.
        """
        config = copy.deepcopy(config)  # avoid in-place modification

        local_devices = jax.local_devices()
        if enable_multi_devices or len(local_devices) > 1:
            raise NotImplementedError("Multi-devices is not supported yet.")

        if enable_jit:
            cls.enable_jit()

        return cls._build_from_config(config)

    @classmethod
    def _build_from_config(cls, config: DictConfig) -> Self:
        """Construct the concrete workflow; to be provided by subclasses."""
        raise NotImplementedError

    def setup(self, key: chex.PRNGKey) -> State:
        """Build the initial workflow state.

        Args:
            key: PRNG key; it is split into the key kept in the state, one for
                agent/optimizer setup and one for replay-buffer setup.

        Returns:
            A ``State`` holding the key, workflow metrics, agent state, both
            optimizer states and the replay-buffer state. When
            ``config.random_timesteps > 0`` the replay buffer has additionally
            been pre-filled via ``_postsetup_replaybuffer``.
        """
        key, agent_key, rb_key = jax.random.split(key, 3)

        # [#pop, ...]
        agent_state, opt_state, ec_opt_state = self._setup_agent_and_optimizer(
            agent_key
        )
        metrics = self._setup_workflow_metrics()
        rb_state = self._setup_replaybuffer(rb_key)

        state = State(
            key=key,
            metrics=metrics,
            agent_state=agent_state,
            opt_state=opt_state,
            ec_opt_state=ec_opt_state,
            replay_buffer_state=rb_state,
        )

        needs_random_fill = self.config.random_timesteps > 0
        if not needs_random_fill:
            return state

        logger.info("Start replay buffer post-setup")
        state = self._postsetup_replaybuffer(state)
        logger.info("Complete replay buffer post-setup")
        return state

    def _setup_workflow_metrics(self) -> MetricBase:
        """Return a fresh ``WorkflowMetric`` instance."""
        return WorkflowMetric()

    def _setup_agent_and_optimizer(
        self, key: chex.PRNGKey
    ) -> tuple[AgentState, chex.ArrayTree, ECState]:
        """Create agent state, optimizer state and EC optimizer state.

        To be provided by subclasses.
        """
        raise NotImplementedError

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Initialise the replay buffer from a dummy transition.

        Args:
            key: PRNG key used to sample the dummy observation.

        Returns:
            The state returned by ``self.replay_buffer.init``.
        """
        # Placeholder values used only to initialise the replay buffer.
        obs_template = self.env.obs_space.sample(key)
        action_template = jnp.zeros(self.env.action_space.shape)
        reward_template = jnp.zeros(())
        termination_template = jnp.zeros(())

        env_extras = PyTreeDict(
            {"ori_obs": obs_template, "termination": termination_template}
        )

        # next_obs and dones are intentionally left unset on the template.
        template_batch = SampleBatch(
            obs=obs_template,
            actions=action_template,
            rewards=reward_template,
            extras=PyTreeDict(
                policy_extras=PyTreeDict(),
                env_extras=env_extras,
            ),
        )

        return self.replay_buffer.init(template_batch)

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Pre-fill the replay buffer with transitions from a random agent.

        A separate auto-resetting environment is created, a ``RandomAgent`` is
        rolled out for ``config.random_timesteps // config.num_envs`` steps,
        and the cleaned, flattened trajectory is added to the replay buffer.
        If the agent state carries an observation preprocessor state and the
        rollout is non-empty, its running statistics are updated with the
        collected observations.

        These random transitions are deliberately not counted in the workflow
        metrics, so that the x-axis (sampled_episodes/iterations) stays
        aligned across different runs.

        Args:
            state: Current workflow state.

        Returns:
            ``state`` with ``key``, ``agent_state`` and
            ``replay_buffer_state`` replaced.
        """
        cfg = self.config
        obs_space = self.env.obs_space
        action_space = self.env.action_space

        # A separate autoreset env is needed to fill the replay buffer.
        fill_env = create_env(
            cfg.env,
            episode_length=cfg.env.max_episode_steps,
            parallel=cfg.num_envs,
            autoreset_mode=AutoresetMode.NORMAL,
            record_ori_obs=True,
        )

        # ==== fill random transitions ====
        next_key, env_key, rollout_key = jax.random.split(state.key, num=3)

        random_agent = RandomAgent()
        random_agent_state = random_agent.init(
            obs_space, action_space, jax.random.PRNGKey(0)
        )

        steps_per_env = cfg.random_timesteps // cfg.num_envs

        initial_env_state = fill_env.reset(env_key)
        transitions, _ = rollout(
            env_fn=fill_env.step,
            action_fn=random_agent.compute_actions,
            env_state=initial_env_state,
            agent_state=random_agent_state,
            key=rollout_key,
            rollout_length=steps_per_env,
            env_extra_fields=("ori_obs", "termination"),
        )

        # [T, B, ...] -> [T*B, ...]
        transitions = tree_stop_gradient(
            flatten_rollout_trajectory(clean_trajectory(transitions))
        )

        filled_rb_state = self.replay_buffer.add(
            state.replay_buffer_state, transitions
        )

        agent_state = state.agent_state
        has_obs_preprocessor = agent_state.obs_preprocessor_state is not None
        if has_obs_preprocessor and steps_per_env > 0:
            updated_stats = running_statistics.update(
                agent_state.obs_preprocessor_state,
                transitions.obs,
            )
            agent_state = agent_state.replace(obs_preprocessor_state=updated_stats)

        return state.replace(
            key=next_key,
            agent_state=agent_state,
            replay_buffer_state=filled_rb_state,
        )

    def _rl_injection(self, *args, **kwargs):
        """RL-injection hook; to be provided by subclasses."""
        raise NotImplementedError

    def evaluate(self, state: State) -> tuple[MetricBase, State]:
        """Evaluate the workflow; to be provided by subclasses.

        Args:
            state: Current workflow state.

        Returns:
            A tuple of evaluation metrics and the resulting state.
        """
        raise NotImplementedError

    @classmethod
    def enable_jit(cls) -> None:
        """Replace selected methods on ``cls`` with ``jax.jit``-wrapped ones.

        ``step``, ``evaluate`` and ``_postsetup_replaybuffer`` are wrapped with
        the first positional argument (``self``) marked as static.
        """
        for method_name in ("step", "evaluate", "_postsetup_replaybuffer"):
            jitted = jax.jit(getattr(cls, method_name), static_argnums=(0,))
            setattr(cls, method_name, jitted)