evorl.rollout
Module Contents
Classes
Functions
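| env_step() | Collect one-step data. |
| eval_env_step() | Collect one-step data in evaluation mode. |
| eval_rollout_episode() | Evaluate a batch of episodic trajectories. |
| fast_eval_rollout_episode() | Quickly evaluate a batch of episodic trajectories. |
| rollout() | Collect trajectories of length rollout_length. |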
API
- evorl.rollout.env_step(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey, env_extra_fields: collections.abc.Sequence[str] = ()) -> tuple[evorl.sample_batch.SampleBatch, evorl.envs.EnvState]
Collect one-step data.
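One step of collection can be pictured with the minimal sketch below. It is not evorl's implementation: ToyEnvState, ToyTransition, toy_env_step, toy_action_fn and env_step_sketch are hypothetical stand-ins, and the sketch assumes a simplified action_fn(agent_state, obs, key) -> actions signature; the real evorl.agent.AgentActionFn signature may differ.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyEnvState(NamedTuple):
    # Hypothetical stand-in for evorl.envs.EnvState, batched over B envs.
    obs: jax.Array     # (B,)
    reward: jax.Array  # (B,)
    done: jax.Array    # (B,), 1.0 once an episode has ended
    t: jax.Array       # (B,), step counter


class ToyTransition(NamedTuple):
    # Hypothetical stand-in for evorl.sample_batch.SampleBatch.
    obs: jax.Array
    actions: jax.Array
    rewards: jax.Array
    dones: jax.Array
    next_obs: jax.Array


def toy_env_step(env_state: ToyEnvState, actions: jax.Array) -> ToyEnvState:
    # Hypothetical vmapped step: the reward equals the action, episodes last 3 steps.
    t = env_state.t + 1
    return ToyEnvState(
        obs=env_state.obs + actions,
        reward=actions,
        done=(t >= 3).astype(jnp.float32),
        t=t,
    )


def toy_action_fn(agent_state, obs, key):
    # Hypothetical deterministic policy; the key is unused here.
    del key
    return jnp.ones_like(obs) * agent_state


def env_step_sketch(env_fn, action_fn, env_state, agent_state, key):
    # 1) Query the agent on the current observations.
    actions = action_fn(agent_state, env_state.obs, key)
    # 2) Advance every env in the batch by one step.
    next_env_state = env_fn(env_state, actions)
    # 3) Pack (obs, action, reward, done, next_obs) into one transition record.
    transition = ToyTransition(
        obs=env_state.obs,
        actions=actions,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
        next_obs=next_env_state.obs,
    )
    return transition, next_env_state


state0 = ToyEnvState(
    obs=jnp.zeros(2), reward=jnp.zeros(2), done=jnp.zeros(2),
    t=jnp.zeros(2, dtype=jnp.int32),
)
# env_fn and action_fn are Python callables, so they are marked static for jit.
transition, state1 = jax.jit(env_step_sketch, static_argnums=(0, 1))(
    toy_env_step, toy_action_fn, state0, 1.0, jax.random.PRNGKey(0)
)
```

The transition stores the observation the action was taken on together with the reward and done flag produced by that action, which is the layout later stacked along a time axis by the rollout functions.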
- evorl.rollout.eval_env_step(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey) -> tuple[evorl.sample_batch.SampleBatch, evorl.envs.EnvState]
Collect one-step data in evaluation mode.
- evorl.rollout.eval_rollout_episode(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey, rollout_length: int) -> tuple[evorl.sample_batch.SampleBatch, evorl.envs.EnvState]
Evaluate a batch of episodic trajectories.
It avoids unnecessary calls to env_step() once all environments are done; however, the agent's action function is still called after that point. This mechanism does not work when the function is wrapped by jax.vmap().
- Parameters:
env_fn – step function of a vmapped env without autoreset.
action_fn – The agent's action function, e.g., agent.compute_actions.
env_state – State of the environment.
agent_state – State of the agent.
key – PRNG key.
rollout_length – The length of the episodes. This value is usually the same as the env's max_episode_steps, or smaller.
- Returns:
A tuple (trajectory, env_state).
- trajectory: SampleBatch with shape (T, B, …), where T=rollout_length and B=#envs. When an episode is terminated
- env_state: last env_state after the rollout.
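The skip-once-all-done behavior can be sketched with jax.lax.scan and jax.lax.cond. This is a minimal illustration under assumptions, not evorl's code: the toy env (no autoreset, per-env horizons, no reward after an env is done), the simplified action_fn signature, and the choice to emit a zero-reward, done=1 padding step once every env has finished are all hypothetical. The action function runs on every step, while env_fn sits behind lax.cond. Under jax.vmap a batched predicate turns lax.cond into a select that evaluates both branches, which is why the saving disappears there.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyEnvState(NamedTuple):
    # Hypothetical stand-in for evorl.envs.EnvState, batched over B envs.
    obs: jax.Array
    reward: jax.Array
    done: jax.Array     # sticky: stays 1.0 after the episode ends (no autoreset)
    t: jax.Array
    horizon: jax.Array  # hypothetical per-env episode length


class ToyTransition(NamedTuple):
    # Hypothetical stand-in for evorl.sample_batch.SampleBatch.
    obs: jax.Array
    actions: jax.Array
    rewards: jax.Array
    dones: jax.Array


def toy_env_step(env_state, actions):
    # Hypothetical env without autoreset: finished envs stop collecting reward.
    alive = 1.0 - env_state.done
    t = env_state.t + 1
    done = jnp.maximum(env_state.done, (t >= env_state.horizon).astype(jnp.float32))
    return ToyEnvState(env_state.obs + actions * alive, actions * alive, done, t, env_state.horizon)


def toy_action_fn(agent_state, obs, key):
    del key  # hypothetical deterministic policy
    return jnp.ones_like(obs) * agent_state


def eval_rollout_episode_sketch(env_fn, action_fn, env_state, agent_state, key, rollout_length):
    def _step(carry, _):
        env_state, key = carry
        key, step_key = jax.random.split(key)
        # The policy is evaluated on every step, even after all envs are done.
        actions = action_fn(agent_state, env_state.obs, step_key)

        def _env_step(actions):
            next_state = env_fn(env_state, actions)
            return ToyTransition(env_state.obs, actions, next_state.reward, next_state.done), next_state

        def _skip(actions):
            # All envs are done: keep the state frozen and emit a padding step.
            padding = ToyTransition(env_state.obs, actions,
                                    jnp.zeros_like(env_state.reward), jnp.ones_like(env_state.done))
            return padding, env_state

        transition, next_state = jax.lax.cond(env_state.done.all(), _skip, _env_step, actions)
        return (next_state, key), transition

    (env_state, _), trajectory = jax.lax.scan(_step, (env_state, key), None, length=rollout_length)
    return trajectory, env_state  # trajectory fields have shape (T, B)


state0 = ToyEnvState(obs=jnp.zeros(2), reward=jnp.zeros(2), done=jnp.zeros(2),
                     t=jnp.zeros(2, dtype=jnp.int32), horizon=jnp.array([2, 3], dtype=jnp.int32))
traj, final_state = eval_rollout_episode_sketch(
    toy_env_step, toy_action_fn, state0, 1.0, jax.random.PRNGKey(0), rollout_length=5)
```

With horizons of 2 and 3 steps, both envs are done after step 3, so env_fn is skipped for steps 4 and 5 (final_state.t stays at 3), yet the trajectory still has the full (T, B) = (5, 2) shape.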
- evorl.rollout.fast_eval_rollout_episode(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey, rollout_length: int) -> tuple[evorl.types.PyTreeDict, evorl.envs.EnvState]
Quickly evaluate a batch of episodic trajectories.
An even faster implementation than eval_rollout_episode(). It achieves early termination (stopping once all environments are done) when it is not wrapped by jax.vmap(). However, this method does not collect trajectory data; it only returns an aggregated metrics dict with the keys (episode_returns, episode_lengths), which is useful in cases such as evaluation.
- Parameters:
env_fn – step function of a vmapped env without autoreset.
action_fn – The agent's action function, e.g., agent.compute_actions.
env_state – State of the environment.
agent_state – State of the agent.
key – PRNG key.
rollout_length – The length of the episodes. This value is usually the same as the env's max_episode_steps.
- Returns:
A tuple (metrics, env_state).
- metrics: a dict (evorl.types.PyTreeDict) with the keys episode_returns and episode_lengths.
- env_state: last env_state after evaluation.
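A rough sketch of the early-termination idea uses jax.lax.while_loop, so the loop exits as soon as every env is done or rollout_length steps have run. As in the other sketches, the toy env, the simplified action_fn signature and the plain dict standing in for evorl.types.PyTreeDict are assumptions for illustration only. Only running totals live in the loop carry, so no per-step trajectory is stored.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyEnvState(NamedTuple):
    # Hypothetical stand-in for evorl.envs.EnvState, batched over B envs.
    obs: jax.Array
    reward: jax.Array
    done: jax.Array     # sticky: stays 1.0 after the episode ends (no autoreset)
    t: jax.Array
    horizon: jax.Array  # hypothetical per-env episode length


def toy_env_step(env_state, actions):
    # Hypothetical env without autoreset: finished envs stop collecting reward.
    alive = 1.0 - env_state.done
    t = env_state.t + 1
    done = jnp.maximum(env_state.done, (t >= env_state.horizon).astype(jnp.float32))
    return ToyEnvState(env_state.obs + actions * alive, actions * alive, done, t, env_state.horizon)


def toy_action_fn(agent_state, obs, key):
    del key  # hypothetical deterministic policy
    return jnp.ones_like(obs) * agent_state


def fast_eval_rollout_episode_sketch(env_fn, action_fn, env_state, agent_state, key, rollout_length):
    def _cond(carry):
        step, env_state, _, _, _ = carry
        # Keep looping while steps remain and at least one env is still running.
        return jnp.logical_and(step < rollout_length, jnp.logical_not(env_state.done.all()))

    def _body(carry):
        step, env_state, key, returns, lengths = carry
        key, step_key = jax.random.split(key)
        actions = action_fn(agent_state, env_state.obs, step_key)
        next_state = env_fn(env_state, actions)
        alive = 1.0 - env_state.done  # mask out envs whose episode already ended
        returns = returns + next_state.reward * alive
        lengths = lengths + alive.astype(jnp.int32)
        return step + 1, next_state, key, returns, lengths

    init = (jnp.array(0, dtype=jnp.int32), env_state, key,
            jnp.zeros_like(env_state.reward), jnp.zeros_like(env_state.t))
    _, env_state, _, returns, lengths = jax.lax.while_loop(_cond, _body, init)
    return {"episode_returns": returns, "episode_lengths": lengths}, env_state


state0 = ToyEnvState(obs=jnp.zeros(2), reward=jnp.zeros(2), done=jnp.zeros(2),
                     t=jnp.zeros(2, dtype=jnp.int32), horizon=jnp.array([2, 3], dtype=jnp.int32))
metrics, final_state = fast_eval_rollout_episode_sketch(
    toy_env_step, toy_action_fn, state0, 1.0, jax.random.PRNGKey(0), rollout_length=10)
```

Although rollout_length is 10, the loop stops after 3 iterations (final_state.t is 3) because both toy episodes have ended by then.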
- evorl.rollout.rollout(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey, rollout_length: int, env_extra_fields: collections.abc.Sequence[str] = ()) -> tuple[evorl.sample_batch.SampleBatch, evorl.envs.EnvState]
Collect trajectories of length rollout_length.
This is a general rollout method for collecting trajectories from a vectorized env. When the env enables autoreset, the returned sequential trajectory data can contain segments from multiple episodes.
- Parameters:
env_fn – step function of a vmapped env.
action_fn – The agent's action function, e.g., agent.compute_actions.
env_state – State of the environment.
agent_state – State of the agent.
key – PRNG key.
rollout_length – The length of the trajectory to collect.
env_extra_fields – Extra fields collected from env_state.info into trajectory.extras.env_extras.
- Returns:
A tuple (trajectory, env_state).
- trajectory: SampleBatch object with shape (T, B, …), where T=rollout_length and B=#envs in env_fn.
- env_state: last env_state after the rollout.
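To see why an autoreset env can mix episodes within one trajectory, the sketch below runs a plain jax.lax.scan over a toy autoreset env and copies the requested info fields into an extras dict. The env, its info key raw_t, the simplified action_fn signature, and the flat env_extras dict (rather than trajectory.extras.env_extras on a real SampleBatch) are all hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyEnvState(NamedTuple):
    # Hypothetical stand-in for evorl.envs.EnvState, batched over B envs.
    obs: jax.Array
    reward: jax.Array
    done: jax.Array
    t: jax.Array
    info: dict  # per-step extra information, e.g. {"raw_t": ...}


class ToyTransition(NamedTuple):
    # Hypothetical stand-in for evorl.sample_batch.SampleBatch.
    obs: jax.Array
    actions: jax.Array
    rewards: jax.Array
    dones: jax.Array
    env_extras: dict


def toy_autoreset_step(env_state, actions):
    # Hypothetical env with autoreset: episodes last 3 steps, then restart at t=0, obs=0.
    t = env_state.t + 1
    done = (t >= 3).astype(jnp.float32)
    return ToyEnvState(
        obs=jnp.where(done > 0, 0.0, env_state.obs + actions),
        reward=actions,
        done=done,
        t=jnp.where(done > 0, 0, t),
        info={"raw_t": t},  # step index before the reset is applied
    )


def toy_action_fn(agent_state, obs, key):
    del key  # hypothetical deterministic policy
    return jnp.ones_like(obs) * agent_state


def rollout_sketch(env_fn, action_fn, env_state, agent_state, key, rollout_length, env_extra_fields=()):
    def _step(carry, _):
        env_state, key = carry
        key, step_key = jax.random.split(key)
        actions = action_fn(agent_state, env_state.obs, step_key)
        next_state = env_fn(env_state, actions)
        # Copy only the requested info fields into the transition.
        extras = {name: next_state.info[name] for name in env_extra_fields}
        transition = ToyTransition(env_state.obs, actions, next_state.reward, next_state.done, extras)
        return (next_state, key), transition

    (env_state, _), trajectory = jax.lax.scan(_step, (env_state, key), None, length=rollout_length)
    return trajectory, env_state  # every field is stacked to shape (T, B)


state0 = ToyEnvState(obs=jnp.zeros(2), reward=jnp.zeros(2), done=jnp.zeros(2),
                     t=jnp.zeros(2, dtype=jnp.int32), info={"raw_t": jnp.zeros(2, dtype=jnp.int32)})
traj, final_state = rollout_sketch(toy_autoreset_step, toy_action_fn, state0, 1.0,
                                   jax.random.PRNGKey(0), rollout_length=7, env_extra_fields=("raw_t",))
```

The two done=1 entries in each env's 7-step column split it into three episode segments, so downstream code should use the dones to separate them rather than treat the trajectory as one episode.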