1from collections.abc import Sequence
2
3import chex
4import jax
5
6from evorl.agent import AgentState, AgentActionFn
7from evorl.envs import EnvState, EnvStepFn
8from evorl.sample_batch import SampleBatch
9from evorl.types import pytree_field
10from evorl.utils.jax_utils import rng_split
11
12from .episode_collector import EpisodeCollector, RolloutFn
13
14
[docs]
def env_step(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,  # readonly
    key: chex.PRNGKey,
) -> tuple[SampleBatch, EnvState]:
    """Advance the environments by one step and record an obs-only transition.

    The policy is queried with a batch that contains only the current
    observation. The extra outputs returned by ``action_fn`` are discarded,
    and the chosen actions themselves are not stored in the transition.

    Args:
        env_fn: Step function, called as ``env_fn(env_state, actions)``.
        action_fn: Policy function, called as
            ``action_fn(agent_state, sample_batch, key)``; must return a
            pair ``(actions, extras)``.
        env_state: Environment state before the step; its ``obs`` is used
            both as the policy input and as the recorded observation.
        agent_state: Agent state forwarded to ``action_fn`` (not modified
            here).
        key: PRNG key forwarded to ``action_fn``.

    Returns:
        A ``(transition, next_env_state)`` pair. ``transition`` holds the
        pre-step ``obs`` together with ``rewards`` and ``dones`` read from
        the post-step environment state.
    """
    # Read once; reused for the policy input and the recorded transition.
    current_obs = env_state.obs

    # Policy input batch: [#envs, ...]
    actions, _policy_extras = action_fn(
        agent_state, SampleBatch(obs=current_obs), key
    )
    next_env_state = env_fn(env_state, actions)

    # Reward/done come from the state produced by this step.
    transition = SampleBatch(
        obs=current_obs,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
    )
    return transition, next_env_state
35
36
[docs]
def rollout(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    env_extra_fields: Sequence[str] = (),
) -> tuple[SampleBatch, EnvState]:
    """Collect a fixed-length trajectory by scanning ``env_step``.

    Args:
        env_fn: Environment step function passed through to ``env_step``.
        action_fn: Policy function passed through to ``env_step``.
        env_state: Environment state at the start of the rollout.
        agent_state: Agent state used unchanged at every step.
        key: PRNG key; split once per step.
        rollout_length: Number of scan iterations (steps) to run.
        env_extra_fields: Not used by this implementation.
            NOTE(review): presumably kept so the signature matches
            ``RolloutFn`` -- confirm against ``episode_collector``.

    Returns:
        A ``(trajectory, final_env_state)`` pair, where ``trajectory`` is
        the per-step transitions stacked by ``jax.lax.scan``
        (``[T, #envs, ...]``) and ``final_env_state`` is the environment
        state after the last step.
    """

    def _scan_body(carry, _):
        state, rng = carry
        # The first half of the split is carried to the next iteration;
        # the second half is consumed by this step.
        carry_rng, step_rng = rng_split(rng, 2)

        # transition: [#envs, ...]
        transition, next_state = env_step(
            env_fn=env_fn,
            action_fn=action_fn,
            env_state=state,
            agent_state=agent_state,
            key=step_rng,
        )
        return (next_state, carry_rng), transition

    initial_carry = (env_state, key)

    # trajectory: [T, #envs, ...]
    final_carry, trajectory = jax.lax.scan(
        _scan_body, initial_carry, (), length=rollout_length
    )
    final_env_state, _ = final_carry

    return trajectory, final_env_state
67
68
[docs]
class EpisodeObsCollector(EpisodeCollector):
    """Streamlined episode collector for observation only.

    The only thing this subclass sets is the default ``rollout_fn``: it
    points at this module's ``rollout``, whose transitions carry just
    ``obs``, ``rewards`` and ``dones``.
    """

    # Declared with static=True via pytree_field; the default is the
    # module-level ``rollout`` function.
    rollout_fn: RolloutFn = pytree_field(
        static=True,
        default=rollout,
    )