1from collections.abc import Sequence
2from functools import partial
3from typing import Protocol
4
5import chex
6import jax
7import jax.numpy as jnp
8
9from evorl.agent import AgentActionFn, AgentState
10from evorl.envs import EnvState, EnvStepFn
11from evorl.sample_batch import SampleBatch
12from evorl.types import PyTreeDict
13from evorl.utils.jax_utils import rng_split
14
15# TODO: add RNN Policy support
16
# Public rollout entry points re-exported by ``from <module> import *``.
__all__ = ["rollout", "eval_rollout_episode", "fast_eval_rollout_episode"]
22
23
[docs]
class RolloutFn(Protocol):
    """Structural type for rollout callables.

    Any callable accepting ``(env_fn, action_fn, env_state, agent_state, key,
    rollout_length, *args, **kwargs)`` satisfies this protocol; the return
    type is intentionally left unspecified.
    """

    def __call__(
        self,
        env_fn: EnvStepFn,
        action_fn: AgentActionFn,
        env_state: EnvState,
        agent_state: AgentState,
        key: chex.PRNGKey,
        rollout_length: int,
        *args,
        **kwargs,
    ):
        """Roll out ``rollout_length`` steps starting from ``env_state``."""
        ...
37
38
[docs]
def env_step(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,  # readonly
    key: chex.PRNGKey,
    env_extra_fields: Sequence[str] = (),
) -> tuple[SampleBatch, EnvState]:
    """Collect one-step data.

    The agent is queried on the current observations, the env is advanced with
    the resulting actions, and the step is packed into a transition.

    Args:
        env_fn: `step` function of the (vmapped) env.
        action_fn: The agent's action function, called as
            ``action_fn(agent_state, sample_batch, key)``.
        env_state: Current env state; its ``obs`` is fed to the agent.
        agent_state: State of the agent (not modified).
        key: PRNG key passed to ``action_fn``.
        env_extra_fields: Keys copied from the next state's ``info`` into
            ``transition.extras.env_extras``; keys missing from ``info`` are
            silently skipped.

    Returns:
        A tuple ``(transition, next_env_state)``.
    """
    current_obs = env_state.obs

    # Only observations are exposed to the agent: [#envs, ...]
    actions, policy_extras = action_fn(
        agent_state, SampleBatch(obs=current_obs), key
    )
    next_state = env_fn(env_state, actions)

    next_info = next_state.info
    requested = {
        field: next_info[field] for field in env_extra_fields if field in next_info
    }
    extras = PyTreeDict(
        policy_extras=policy_extras,
        env_extras=PyTreeDict(requested),
    )

    transition = SampleBatch(
        obs=current_obs,
        actions=actions,
        rewards=next_state.reward,
        dones=next_state.done,
        next_obs=next_state.obs,
        extras=extras,
    )
    return transition, next_state
67
68
[docs]
def eval_env_step(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,  # readonly
    key: chex.PRNGKey,
) -> tuple[SampleBatch, EnvState]:
    """Collect one-step data in evaluation mode.

    Same as a regular step, but the returned transition only carries
    ``rewards`` and ``dones``; the policy extras are discarded.

    Returns:
        A tuple ``(transition, next_env_state)``.
    """
    # The agent only sees the observations: [#envs, ...]
    actions, _ = action_fn(agent_state, SampleBatch(obs=env_state.obs), key)
    next_state = env_fn(env_state, actions)

    return SampleBatch(rewards=next_state.reward, dones=next_state.done), next_state
89
90
[docs]
def rollout(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    env_extra_fields: Sequence[str] = (),
) -> tuple[SampleBatch, EnvState]:
    """Collect trajectories with length of `rollout_length`.

    General-purpose rollout for a vectorized env. If the env autoresets, the
    returned sequential data may span several episodes.

    Args:
        env_fn: `step` function of a vmapped env.
        action_fn: The agent's action function. Eg: `agent.compute_actions`.
        env_state: State of the environment.
        agent_state: State of the agent.
        key: PRNG key; split once per step inside the scan.
        rollout_length: Number of steps to collect.
        env_extra_fields: Extra fields collected from `env_state.info` into
            `trajectory.extras.env_extras`.

    Returns:
        A tuple (trajectory, env_state).
        - trajectory: `SampleBatch` with shape (T, B, ...), where
          T=rollout_length and B=#envs in `env_fn`.
        - env_state: last env_state after the rollout.
    """

    def _scan_body(carry, _):
        state, rng = carry
        # The first key is carried forward; the second drives this step.
        rng, step_key = rng_split(rng, 2)

        # step_transition: [#envs, ...]
        step_transition, next_state = env_step(
            env_fn,
            action_fn,
            state,
            agent_state,
            step_key,
            env_extra_fields,
        )
        return (next_state, rng), step_transition

    # trajectory: [T, #envs, ...]
    (final_state, _), trajectory = jax.lax.scan(
        _scan_body,
        (env_state, key),
        (),
        length=rollout_length,
    )
    return trajectory, final_state
142
143
[docs]
def eval_rollout_episode(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
) -> tuple[SampleBatch, EnvState]:
    """Evaluate a batch of episodic trajectories.

    Once every env in the batch reports ``done``, the remaining scan
    iterations skip ``eval_env_step()`` entirely (neither the agent's action
    function nor the env step is executed) and repeat the previous
    transition. When this function is wrapped by ``jax.vmap()``, the
    ``jax.lax.cond`` below is evaluated with both branches, so the saving
    is lost.

    Args:
        env_fn: `step` function of a vmapped env without autoreset.
        action_fn: The agent's action function. Eg: `agent.compute_actions`.
        env_state: State of the environment.
        agent_state: State of the agent.
        key: PRNG key.
        rollout_length: Number of scan steps. This value is usually equal to
            or smaller than the env's `max_episode_steps`.

    Returns:
        A tuple (trajectory, env_state).
        - trajectory: SampleBatch holding only ``rewards`` and ``dones``, with
          shape (T, B, ...), where T=rollout_length, B=#envs. After an env's
          episode terminates, its entries are whatever the env returns when
          stepped again (no autoreset); once all envs are done, the last
          transition is repeated. Mask by ``dones`` when aggregating.
        - env_state: last env_state after rollout.
    """
    _eval_env_step = partial(eval_env_step, env_fn, action_fn)

    def _one_step_rollout(carry, unused_t):
        env_state, current_key, prev_transition = carry
        next_key, current_key = rng_split(current_key, 2)

        # Skip the (action + env) step once every env is done; reuse the
        # previous transition so the scan output keeps a fixed structure.
        transition, env_nstate = jax.lax.cond(
            env_state.done.all(),
            lambda *x: (prev_transition.replace(), env_state.replace()),
            _eval_env_step,
            env_state,
            agent_state,
            current_key,
        )

        return (env_nstate, next_key, transition), transition

    # Derive independent keys: the original code fed `key` both to the
    # bootstrap step and to the scan (where it is split), which is a PRNG
    # key reuse.
    rollout_key, bootstrap_key = rng_split(key, 2)

    # Run one step up front only to obtain a transition with the right
    # structure for the scan carry. It does not appear in the trajectory
    # unless env_state is already all-done at the start.
    bootstrap_transition, _ = _eval_env_step(env_state, agent_state, bootstrap_key)

    (env_state, _, _), trajectory = jax.lax.scan(
        _one_step_rollout,
        (env_state, rollout_key, bootstrap_transition),
        (),
        length=rollout_length,
    )

    return trajectory, env_state
199
200
[docs]
def fast_eval_rollout_episode(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
) -> tuple[PyTreeDict, EnvState]:
    """Fast evaluation of a batch of episodic trajectories.

    An even faster variant of `eval_rollout_episode()`. It uses a
    ``jax.lax.while_loop`` that exits as soon as every env is done (when not
    wrapped by ``jax.vmap()``). No trajectory is collected. Only the
    aggregated metrics ``episode_returns`` and ``episode_lengths`` are
    returned, which suits evaluation.

    Args:
        env_fn: `step` function of a vmapped env without autoreset.
        action_fn: The agent's action function. Eg: `agent.compute_actions`.
        env_state: State of the environment.
        agent_state: State of the agent.
        key: PRNG key.
        rollout_length: Upper bound on episode length. The loop stops once
            any env's episode length reaches it, or once all envs are done.
            Usually equal to the env's `max_episode_steps`.

    Returns:
        metrics: Dict(episode_returns, episode_lengths).
        env_state: Last env_state after evaluation.
    """
    step_fn = partial(eval_env_step, env_fn, action_fn)

    def _should_continue(carry):
        # while_loop keeps iterating while this predicate is True.
        state, _, metrics = carry
        within_horizon = (metrics.episode_lengths < rollout_length).all()
        still_running = ~state.done.all()
        return within_horizon & still_running

    def _loop_body(carry):
        state, rng, metrics = carry
        rng, step_key = rng_split(rng, 2)

        transition, next_state = step_fn(state, agent_state, step_key)

        # Only count this step for envs that were not done before it.
        alive = 1 - state.done
        updated = PyTreeDict(
            episode_returns=metrics.episode_returns + alive * transition.rewards,
            episode_lengths=metrics.episode_lengths + alive.astype(jnp.int32),
        )
        return next_state, rng, updated

    shape = env_state.reward.shape
    init_metrics = PyTreeDict(
        episode_returns=jnp.zeros(shape),
        episode_lengths=jnp.zeros(shape, dtype=jnp.int32),
    )

    final_state, _, metrics = jax.lax.while_loop(
        _should_continue,
        _loop_body,
        (env_state, key, init_metrics),
    )

    return metrics, final_state
260 episode_lengths=jnp.zeros(batch_shape, dtype=jnp.int32),
261 ),
262 ),
263 )
264
265 return metrics, env_state