Source code for evorl.rollout

  1from collections.abc import Sequence
  2from functools import partial
  3from typing import Protocol
  4
  5import chex
  6import jax
  7import jax.numpy as jnp
  8
  9from evorl.agent import AgentActionFn, AgentState
 10from evorl.envs import EnvState, EnvStepFn
 11from evorl.sample_batch import SampleBatch
 12from evorl.types import PyTreeDict
 13from evorl.utils.jax_utils import rng_split
 14
 15# TODO: add RNN Policy support
 16
 17__all__ = [
 18    "rollout",
 19    "eval_rollout_episode",
 20    "fast_eval_rollout_episode",
 21]
 22
 23
class RolloutFn(Protocol):
    """Structural type describing a rollout function.

    Any callable whose signature is compatible with `__call__` below
    satisfies this protocol. Extra positional and keyword arguments are
    permitted through `*args` and `**kwargs`.
    """

    def __call__(
        self,
        env_fn: EnvStepFn,
        action_fn: AgentActionFn,
        env_state: EnvState,
        agent_state: AgentState,
        key: chex.PRNGKey,
        rollout_length: int,
        *args,
        **kwargs,
    ):
        """Run a rollout.

        Args:
            env_fn: Step function of the environment.
            action_fn: The agent's action function.
            env_state: State of the environment.
            agent_state: State of the agent.
            key: PRNG key.
            rollout_length: Number of steps to roll out.
            *args: Additional implementation-specific positional arguments.
            **kwargs: Additional implementation-specific keyword arguments.
        """
        ...
37 38
def env_step(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,  # readonly
    key: chex.PRNGKey,
    env_extra_fields: Sequence[str] = (),
) -> tuple[SampleBatch, EnvState]:
    """Advance the environment by one step and record the transition.

    Args:
        env_fn: Step function of the environment.
        action_fn: The agent's action function.
        env_state: Environment state before the step.
        agent_state: State of the agent; only read here.
        key: PRNG key passed to `action_fn`.
        env_extra_fields: Names of entries to copy from the new env state's
            `info` into `extras.env_extras`. Names missing from `info` are
            skipped silently.

    Returns:
        A tuple (transition, env_nstate).
            - transition: `SampleBatch` holding obs, actions, rewards, dones,
              next_obs and extras (`policy_extras` and `env_extras`).
            - env_nstate: Environment state after the step.
    """
    current_obs = env_state.obs

    # policy input: [#envs, ...]
    policy_input = SampleBatch(obs=current_obs)
    actions, policy_extras = action_fn(agent_state, policy_input, key)

    next_env_state = env_fn(env_state, actions)

    # Keep only the requested info entries that actually exist.
    step_info = next_env_state.info
    selected_info = {}
    for field_name in env_extra_fields:
        if field_name in step_info:
            selected_info[field_name] = step_info[field_name]
    env_extras = PyTreeDict(selected_info)

    extras = PyTreeDict(policy_extras=policy_extras, env_extras=env_extras)

    transition = SampleBatch(
        obs=current_obs,
        actions=actions,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
        next_obs=next_env_state.obs,
        extras=extras,
    )

    return transition, next_env_state
67 68
def eval_env_step(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,  # readonly
    key: chex.PRNGKey,
) -> tuple[SampleBatch, EnvState]:
    """Advance the environment by one step in evaluation mode.

    Unlike `env_step`, the returned transition carries only rewards and
    dones; observations, actions and extras are not recorded.

    Args:
        env_fn: Step function of the environment.
        action_fn: The agent's action function.
        env_state: Environment state before the step.
        agent_state: State of the agent; only read here.
        key: PRNG key passed to `action_fn`.

    Returns:
        A tuple (transition, env_nstate).
            - transition: `SampleBatch` with `rewards` and `dones` only.
            - env_nstate: Environment state after the step.
    """
    # policy input: [#envs, ...]
    policy_input = SampleBatch(obs=env_state.obs)

    # The policy extras are not needed for evaluation.
    actions, _policy_extras = action_fn(agent_state, policy_input, key)
    next_env_state = env_fn(env_state, actions)

    outcome = SampleBatch(
        rewards=next_env_state.reward,
        dones=next_env_state.done,
    )

    return outcome, next_env_state
89 90
def rollout(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    env_extra_fields: Sequence[str] = (),
) -> tuple[SampleBatch, EnvState]:
    """Collect trajectories of length `rollout_length`.

    General-purpose rollout for gathering trajectories from a vectorized env.
    If the env enables autoreset, the returned sequential trajectory data may
    contain segments from several episodes.

    Args:
        env_fn: `step` function of a vmapped env.
        action_fn: The agent's action function. Eg: `agent.compute_actions`.
        env_state: State of the environment.
        agent_state: State of the agent.
        key: PRNG key.
        rollout_length: The length of the trajectory to collect.
        env_extra_fields: Extra fields collected from `env_state.info` into
            `trajectory.extras.env_extras`.

    Returns:
        A tuple (trajectory, env_state).
            - trajectory: `SampleBatch` object with shape (T, B, ...), where
              T=rollout_length, B=#envs in `env_fn`.
            - env_state: Last env_state after the rollout.
    """

    def _scan_body(loop_state, _):
        cur_env_state, rng = loop_state
        # First half seeds the next iteration, second half drives this step.
        carry_key, step_key = rng_split(rng, 2)

        # step_data: [#envs, ...]
        step_data, next_env_state = env_step(
            env_fn,
            action_fn,
            cur_env_state,
            agent_state,
            step_key,
            env_extra_fields,
        )

        return (next_env_state, carry_key), step_data

    initial_loop_state = (env_state, key)

    # trajectory: [T, #envs, ...]
    final_loop_state, trajectory = jax.lax.scan(
        _scan_body, initial_loop_state, xs=None, length=rollout_length
    )
    last_env_state, _ = final_loop_state

    return trajectory, last_env_state
142 143
def eval_rollout_episode(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
) -> tuple[SampleBatch, EnvState]:
    """Evaluate a batch of episodic trajectories.

    Once every environment is done, further steps are skipped via
    `jax.lax.cond` and the previous transition is repeated instead. When this
    function is wrapped by `jax.vmap()`, this mechanism will not work.

    Args:
        env_fn: `step` function of a vmapped env without autoreset.
        action_fn: The agent's action function. Eg: `agent.compute_actions`.
        env_state: State of the environment.
        agent_state: State of the agent.
        key: PRNG key.
        rollout_length: The length of the episodes. This value is usually the
            same as the env's `max_episode_steps`, or smaller than that.

    Returns:
        A tuple (trajectory, env_state).
            - trajectory: `SampleBatch` with shape (T, B, ...), where
              T=rollout_length, B=#envs. It holds `rewards` and `dones` only.
              After all episodes have terminated, the remaining entries
              repeat the last computed transition.
            - env_state: Last env_state after the rollout.
    """

    def _eval_step(cur_env_state, cur_agent_state, step_key):
        return eval_env_step(
            env_fn, action_fn, cur_env_state, cur_agent_state, step_key
        )

    def _scan_body(loop_state, _):
        cur_env_state, rng, last_transition = loop_state
        carry_key, step_key = rng_split(rng, 2)

        def _hold(*_operands):
            # Every env is done: repeat the previous outputs unchanged.
            return last_transition.replace(), cur_env_state.replace()

        all_done = cur_env_state.done.all()
        step_data, next_env_state = jax.lax.cond(
            all_done,
            _hold,
            _eval_step,
            cur_env_state,
            agent_state,
            step_key,
        )

        return (next_env_state, carry_key, step_data), step_data

    # Run one step up front to obtain a transition that seeds the scan carry.
    # Its resulting env state is discarded. It only shows up in the trajectory
    # if all envs are already done on entry; when env_state comes from
    # env.reset() it is not included. This is manually controlled by the user.
    # NOTE: this step consumes `key`, which is also the initial key of the scan.
    seed_transition, _ = _eval_step(env_state, agent_state, key)

    initial_loop_state = (env_state, key, seed_transition)
    final_loop_state, trajectory = jax.lax.scan(
        _scan_body,
        initial_loop_state,
        xs=None,
        length=rollout_length,
    )
    last_env_state, _, _ = final_loop_state

    return trajectory, last_env_state
199 200
def fast_eval_rollout_episode(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
) -> tuple[PyTreeDict, EnvState]:
    """Quickly evaluate a batch of episodic trajectories.

    An even faster alternative to `eval_rollout_episode()`. It terminates
    early when it is not wrapped by `jax.vmap()`. No trajectory data is
    collected; only an aggregated metrics dict with keys
    (episode_returns, episode_lengths) is returned, which is useful in cases
    like evaluation.

    Args:
        env_fn: `step` function of a vmapped env without autoreset.
        action_fn: The agent's action function. Eg: `agent.compute_actions`.
        env_state: State of the environment.
        agent_state: State of the agent.
        key: PRNG key.
        rollout_length: The length of the episodes. This value is usually the
            same as the env's `max_episode_steps`.

    Returns:
        A tuple (metrics, env_state).
            - metrics: Dict(episode_returns, episode_lengths).
            - env_state: Last env_state after evaluation.
    """

    def _eval_step(cur_env_state, cur_agent_state, step_key):
        return eval_env_step(
            env_fn, action_fn, cur_env_state, cur_agent_state, step_key
        )

    def _keep_running(loop_state):
        cur_env_state, _rng, totals = loop_state
        # Continue while every env is below the step budget ...
        within_budget = (totals.episode_lengths < rollout_length).all()
        # ... and at least one env has not finished yet.
        some_env_active = ~cur_env_state.done.all()
        return within_budget & some_env_active

    def _loop_body(loop_state):
        cur_env_state, rng, totals = loop_state
        carry_key, step_key = rng_split(rng, 2)

        step_data, next_env_state = _eval_step(cur_env_state, agent_state, step_key)

        # Envs that were done before this step contribute nothing further
        # (assumes `done` is 0/1-valued -- TODO confirm).
        still_active = 1 - cur_env_state.done

        new_returns = totals.episode_returns + still_active * step_data.rewards
        new_lengths = totals.episode_lengths + still_active.astype(jnp.int32)
        updated_totals = PyTreeDict(
            episode_returns=new_returns,
            episode_lengths=new_lengths,
        )

        return next_env_state, carry_key, updated_totals

    batch_shape = env_state.reward.shape
    initial_totals = PyTreeDict(
        episode_returns=jnp.zeros(batch_shape),
        episode_lengths=jnp.zeros(batch_shape, dtype=jnp.int32),
    )

    final_loop_state = jax.lax.while_loop(
        _keep_running,
        _loop_body,
        (env_state, key, initial_totals),
    )
    last_env_state, _, metrics = final_loop_state

    return metrics, last_env_state