Source code for evorl.evaluators.ec_evaluator

 1from collections.abc import Sequence
 2
 3import chex
 4import jax
 5
 6from evorl.agent import AgentState, AgentActionFn
 7from evorl.envs import EnvState, EnvStepFn
 8from evorl.sample_batch import SampleBatch
 9from evorl.types import pytree_field
10from evorl.utils.jax_utils import rng_split
11
12from .episode_collector import EpisodeCollector, RolloutFn
13
14
def env_step(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,  # readonly
    key: chex.PRNGKey,
) -> tuple[SampleBatch, EnvState]:
    """Advance the environments by one step and record a minimal transition.

    The agent is queried on the observation held in ``env_state``, the
    resulting actions are applied through ``env_fn``, and a transition
    containing only ``obs``, ``rewards`` and ``dones`` is returned.

    Args:
        env_fn: Callable invoked as ``env_fn(env_state, actions)``; returns
            the next environment state.
        action_fn: Callable invoked as ``action_fn(agent_state, batch, key)``;
            returns ``(actions, policy_extras)``.
        env_state: Environment state before the step; its ``obs`` is fed to
            the agent and stored in the transition.
        agent_state: Agent state handed to ``action_fn``; only read here.
        key: PRNG key forwarded to ``action_fn``.

    Returns:
        ``(transition, next_env_state)`` where ``transition.obs`` is the
        pre-step observation and ``rewards`` / ``dones`` are read from the
        post-step state's ``reward`` / ``done`` attributes.
    """
    # policy_input: [#envs, ...]
    policy_input = SampleBatch(obs=env_state.obs)

    # The second return value (policy extras) is intentionally discarded:
    # this collector does not store it in the transition.
    actions, _policy_extras = action_fn(agent_state, policy_input, key)

    next_env_state = env_fn(env_state, actions)

    # Observation comes from the state *before* the step; reward and done
    # come from the state *after* it.
    recorded = SampleBatch(
        obs=env_state.obs,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
    )
    return recorded, next_env_state
35 36
def rollout(
    env_fn: EnvStepFn,
    action_fn: AgentActionFn,
    env_state: EnvState,
    agent_state: AgentState,
    key: chex.PRNGKey,
    rollout_length: int,
    env_extra_fields: Sequence[str] = (),
) -> tuple[SampleBatch, EnvState]:
    """Collect a fixed-length trajectory by scanning ``env_step``.

    Args:
        env_fn: Environment step function, forwarded to ``env_step``.
        action_fn: Agent action function, forwarded to ``env_step``.
        env_state: Environment state at the start of the rollout.
        agent_state: Agent state used unchanged for every step.
        key: PRNG key; split once per step inside the scan.
        rollout_length: Number of scan iterations (steps) to run.
        env_extra_fields: Accepted in the signature but not referenced by
            this implementation.

    Returns:
        ``(trajectory, final_env_state)``; the trajectory stacks the
        per-step transitions along a leading time axis.
    """

    def _scan_body(carry, _unused_xs):
        state, key_in = carry
        # Order matters: the first split result is carried to the next
        # iteration, the second is consumed by this step.
        carried_key, step_key = rng_split(key_in, 2)

        # step_transition: [#envs, ...]
        step_transition, next_state = env_step(
            env_fn, action_fn, state, agent_state, step_key
        )
        return (next_state, carried_key), step_transition

    initial_carry = (env_state, key)

    # trajectory: [T, #envs, ...]
    final_carry, trajectory = jax.lax.scan(
        _scan_body, initial_carry, (), length=rollout_length
    )
    final_env_state, _final_key = final_carry

    return trajectory, final_env_state
67 68
class EpisodeObsCollector(EpisodeCollector):
    """Streamlined episode collector for observation only.

    Differs from ``EpisodeCollector`` only in the default of ``rollout_fn``,
    which is this module's ``rollout``. That rollout stores just ``obs``,
    ``rewards`` and ``dones`` for each step.
    """

    # Declared with static=True; defaults to the module-level `rollout`.
    rollout_fn: RolloutFn = pytree_field(
        static=True,
        default=rollout,
    )