1import logging
2import math
3from omegaconf import DictConfig
4
5import jax
6import jax.numpy as jnp
7import chex
8
9from evorl.replay_buffers import ReplayBufferState
10from evorl.envs import Discrete
11from evorl.distributed.comm import psum
12from evorl.workflows import OffPolicyWorkflow
13from evorl.sample_batch import SampleBatch
14from evorl.types import State, PyTreeDict
15from evorl.rollout import rollout
16from evorl.utils.rl_toolkits import flatten_rollout_trajectory
17from evorl.utils import running_statistics
18from evorl.utils.jax_utils import tree_stop_gradient, scan_and_last
19from evorl.agent import RandomAgent
20from evorl.recorders import add_prefix
21
22
# Module-level logger; `_rescale_config` reports per-device rescaling through it.
logger: logging.Logger = logging.getLogger(__name__)


class OffPolicyWorkflowTemplate(OffPolicyWorkflow):
    """Common template for off-policy RL workflows trained with TD learning.

    Provides the pieces most off-policy workflows share: rescaling global
    config sizes to per-device sizes, initializing and warming up the replay
    buffer, running several training steps per jitted call, and the outer
    ``learn`` loop (metric logging, evaluation and checkpointing).
    ``step`` and ``evaluate`` are used here but not defined in this class.
    """

    @classmethod
    def _rescale_config(cls, config: DictConfig) -> None:
        """Convert global config sizes to per-device sizes, in place.

        Each listed field is floor-divided by ``jax.device_count()``. A warning
        is logged when a value is not a multiple of the device count, because
        the remainder is silently dropped by the division.

        Args:
            config: Workflow config; mutated in place.
        """
        num_devices = jax.device_count()

        for name in (
            "num_envs",
            "num_eval_envs",
            "replay_buffer_capacity",
            "random_timesteps",
            "learning_start_timesteps",
        ):
            value = getattr(config, name)
            per_device_value = value // num_devices
            if value % num_devices != 0:
                logger.warning(
                    f"{name}({value}) cannot be divided by num_devices({num_devices}), "
                    f"rescale {name} to {per_device_value}"
                )
            setattr(config, name, per_device_value)

    def _setup_replaybuffer(self, key: chex.PRNGKey) -> ReplayBufferState:
        """Initialize the replay buffer from a single dummy transition.

        The dummy transition serves as a template for what will later be added
        to the buffer (tree layout, shapes and dtypes).

        Args:
            key: PRNG key used to sample the dummy observation.

        Returns:
            The initial replay buffer state.
        """
        action_space = self.env.action_space
        obs_space = self.env.obs_space

        if isinstance(action_space, Discrete):
            dummy_action = jnp.zeros((), dtype=jnp.int32)
        else:
            dummy_action = jnp.zeros(action_space.shape)
        dummy_obs = obs_space.sample(key)
        dummy_reward = jnp.zeros(())
        dummy_done = jnp.zeros(())

        # `next_obs` and `dones` are intentionally not part of the layout:
        # `clean_trajectory` sets them to None on every trajectory before it
        # is added to the buffer, so the template must omit them too.
        dummy_sample_batch = SampleBatch(
            obs=dummy_obs,
            actions=dummy_action,
            rewards=dummy_reward,
            extras=PyTreeDict(
                policy_extras=PyTreeDict(),
                env_extras=PyTreeDict(
                    {"ori_obs": dummy_obs, "termination": dummy_done}
                ),
            ),
        )
        return self.replay_buffer.init(dummy_sample_batch)

    def _postsetup_replaybuffer(self, state: State) -> State:
        """Warm up the replay buffer before training starts.

        The buffer is filled in two phases:

        1. about ``config.random_timesteps`` transitions from a ``RandomAgent``;
        2. the remainder up to ``config.learning_start_timesteps``, collected
           by ``self.agent`` with its current (initial) state.

        A phase whose rollout length is not positive is skipped entirely. For
        example, ``random_timesteps=0`` gives an empty phase 1, and
        ``learning_start_timesteps <= random_timesteps`` gives an empty phase 2.
        The observation preprocessor, if present, is updated with every
        collected batch.

        Returns:
            ``state`` with a new key, the filled replay buffer, the updated
            agent state, and ``metrics.sampled_timesteps`` increased by the
            number of transitions summed over the data-parallel axis.
        """
        action_space = self.env.action_space
        obs_space = self.env.obs_space
        config = self.config

        def _rollout(agent, rollout_agent_state, key, rollout_length):
            """Reset the env, roll out ``agent`` and return a flat batch."""
            env_key, rollout_key = jax.random.split(key)
            env_state = self.env.reset(env_key)

            trajectory, env_state = rollout(
                env_fn=self.env.step,
                action_fn=agent.compute_actions,
                env_state=env_state,
                agent_state=rollout_agent_state,
                key=rollout_key,
                rollout_length=rollout_length,
                env_extra_fields=("ori_obs", "termination"),
            )

            # Drop next_obs/dones, then [T, B, ...] -> [T*B, ...].
            trajectory = clean_trajectory(trajectory)
            trajectory = flatten_rollout_trajectory(trajectory)
            return tree_stop_gradient(trajectory)

        def _fill(
            agent,
            rollout_agent_state,
            agent_state,
            replay_buffer_state,
            key,
            rollout_length,
        ):
            """Add one rollout phase to the buffer.

            Returns ``(agent_state, replay_buffer_state, num_local_steps)``.
            ``rollout_length`` is a static Python int (derived from config), so
            the skip below is resolved at trace time.
            """
            if rollout_length <= 0:
                return agent_state, replay_buffer_state, 0

            trajectory = _rollout(agent, rollout_agent_state, key, rollout_length)

            if agent_state.obs_preprocessor_state is not None:
                agent_state = agent_state.replace(
                    obs_preprocessor_state=running_statistics.update(
                        agent_state.obs_preprocessor_state,
                        trajectory.obs,
                        dp_axis_name=self.dp_axis_name,
                    )
                )
            replay_buffer_state = self.replay_buffer.add(
                replay_buffer_state, trajectory
            )
            return agent_state, replay_buffer_state, rollout_length * config.num_envs

        key, random_rollout_key, rollout_key = jax.random.split(state.key, num=3)
        agent_state = state.agent_state
        replay_buffer_state = state.replay_buffer_state

        # ==== phase 1: transitions from a RandomAgent ====
        random_agent = RandomAgent()
        random_agent_state = random_agent.init(
            obs_space, action_space, jax.random.PRNGKey(0)
        )
        agent_state, replay_buffer_state, random_steps = _fill(
            random_agent,
            random_agent_state,
            agent_state,
            replay_buffer_state,
            random_rollout_key,
            config.random_timesteps // config.num_envs,
        )

        # ==== phase 2: top up to learning_start_timesteps with the initial agent ====
        # May be <= 0 when learning_start_timesteps <= random_timesteps; _fill skips it.
        init_rollout_length = math.ceil(
            (config.learning_start_timesteps - random_steps) / config.num_envs
        )
        agent_state, replay_buffer_state, init_steps = _fill(
            self.agent,
            agent_state,
            agent_state,
            replay_buffer_state,
            rollout_key,
            init_rollout_length,
        )

        # Per-device step counts, summed over the data-parallel axis.
        sampled_timesteps = psum(
            jnp.uint32(random_steps + init_steps), axis_name=self.dp_axis_name
        )

        workflow_metrics = state.metrics.replace(
            sampled_timesteps=state.metrics.sampled_timesteps + sampled_timesteps,
        ).all_reduce(dp_axis_name=self.dp_axis_name)

        return state.replace(
            key=key,
            metrics=workflow_metrics,
            agent_state=agent_state,
            replay_buffer_state=replay_buffer_state,
        )

    def _multi_steps(self, state: State):
        """Run ``config.fold_iters`` consecutive ``self.step`` calls in one scan.

        Returns:
            ``(train_metrics, state)``. ``train_metrics`` is what
            ``scan_and_last`` yields for the stacked per-step metrics
            (presumably only the last step's; confirm in its implementation).
            ``state`` is the state after the final step.
        """

        def _step(state, _):
            train_metrics, state = self.step(state)
            return state, train_metrics

        state, train_metrics = scan_and_last(
            _step, state, (), length=self.config.fold_iters
        )
        return train_metrics, state

    def learn(self, state: State) -> State:
        """Train until ``config.total_timesteps`` environment steps are sampled.

        Each outer iteration runs ``_multi_steps`` and writes train and
        workflow metrics. It evaluates every ``config.eval_interval``
        iterations and at the final iteration, and passes the state to the
        checkpoint manager, forcing a save at the final iteration. The replay
        buffer is excluded from checkpoints unless
        ``config.save_replay_buffer`` is set.

        Returns:
            The state after the last iteration.
        """
        num_devices = jax.device_count()
        # config.num_envs is presumably per-device after `_rescale_config`,
        # hence the extra num_devices factor below.
        one_step_timesteps = self.config.rollout_length * self.config.num_envs
        sampled_timesteps = state.metrics.sampled_timesteps.tolist()
        num_iters = math.ceil(
            (self.config.total_timesteps - sampled_timesteps)
            / (one_step_timesteps * self.config.fold_iters * num_devices)
        )
        start_iteration = state.metrics.iterations.tolist()
        # NOTE(review): assumes each `_multi_steps` call advances
        # metrics.iterations by exactly 1; otherwise `final_iteration` is
        # never hit and the final eval / forced checkpoint are skipped.
        final_iteration = num_iters + start_iteration

        for _ in range(num_iters):
            train_metrics, state = self._multi_steps(state)
            workflow_metrics = state.metrics

            iterations = state.metrics.iterations.tolist()
            self.recorder.write(train_metrics.to_local_dict(), iterations)
            self.recorder.write(workflow_metrics.to_local_dict(), iterations)

            is_final = iterations == final_iteration
            if iterations % self.config.eval_interval == 0 or is_final:
                eval_metrics, state = self.evaluate(state)
                self.recorder.write(
                    add_prefix(eval_metrics.to_local_dict(), "eval"), iterations
                )

            saved_state = state
            if not self.config.save_replay_buffer:
                saved_state = skip_replay_buffer_state(saved_state)
            self.checkpoint_manager.save(iterations, saved_state, force=is_final)

        return state

    @classmethod
    def enable_jit(cls) -> None:
        """Jit-compile the warm-up and multi-step methods (plus the base ones).

        ``self`` is marked static (argnum 0), so Python-level branching on
        config values inside these methods is resolved at trace time.
        """
        super().enable_jit()
        cls._postsetup_replaybuffer = jax.jit(
            cls._postsetup_replaybuffer, static_argnums=(0,)
        )
        cls._multi_steps = jax.jit(cls._multi_steps, static_argnums=(0,))

def skip_replay_buffer_state(state: State) -> State:
    """Drop the replay buffer from a workflow state.

    Returns ``state`` with its ``replay_buffer_state`` field set to ``None``.
    Typically applied right before an off-policy workflow state is saved to
    disk, when the buffer itself should not be persisted.
    """
    without_buffer = state.replace(replay_buffer_state=None)
    return without_buffer


def clean_trajectory(trajectory: SampleBatch) -> SampleBatch:
    """Strip the fields the replay buffer does not store.

    Returns ``trajectory`` with ``next_obs`` and ``dones`` set to ``None`` so
    it is suitable for insertion into the replay buffer.
    """
    removed_fields = {"next_obs": None, "dones": None}
    return trajectory.replace(**removed_fields)