evorl.rollout

Module Contents

Classes

Functions

env_step

Collect one-step data.

eval_env_step

Collect one-step data in evaluation mode.

eval_rollout_episode

Evaluate a batch of episodic trajectories.

fast_eval_rollout_episode

Quickly evaluate a batch of episodic trajectories.

rollout

Collect trajectories with length of rollout_length.

API

class evorl.rollout.RolloutFn

Bases: typing.Protocol

evorl.rollout.env_step(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey, env_extra_fields: collections.abc.Sequence[str] = ()) → tuple[evorl.sample_batch.SampleBatch, evorl.envs.EnvState]

Collect one-step data.
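The idea of one-step collection can be illustrated with a minimal sketch. Everything below is an assumption made for illustration: `toy_env_step`, `toy_action_fn`, and `one_step` are hypothetical stand-ins rather than evorl APIs, and a plain dict is used in place of SampleBatch.

```python
import jax
import jax.numpy as jnp


def toy_env_step(env_state, action):
    # Hypothetical env: the observation drifts by the action taken.
    obs = env_state["obs"] + action
    return {"obs": obs, "reward": -jnp.abs(obs), "done": jnp.abs(obs) > 3.0}


def toy_action_fn(agent_state, obs, key):
    # Hypothetical stochastic policy: a fixed bias plus Gaussian noise.
    return agent_state["bias"] + jax.random.normal(key, obs.shape)


def one_step(env_fn, action_fn, env_state, agent_state, key):
    """Query the agent, step the env once, and package the transition."""
    obs = env_state["obs"]
    action = action_fn(agent_state, obs, key)
    next_env_state = env_fn(env_state, action)
    transition = {
        "obs": obs,
        "action": action,
        "reward": next_env_state["reward"],
        "next_obs": next_env_state["obs"],
        "done": next_env_state["done"],
    }
    return transition, next_env_state


env_state = {
    "obs": jnp.zeros(4),
    "reward": jnp.zeros(4),
    "done": jnp.zeros(4, dtype=bool),
}
agent_state = {"bias": jnp.float32(0.5)}
transition, env_state = one_step(
    toy_env_step, toy_action_fn, env_state, agent_state, jax.random.PRNGKey(0)
)
```

The returned pair mirrors the (sample, env_state) shape of the documented return value; the actual fields stored by evorl may differ.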

evorl.rollout.eval_env_step(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey) → tuple[evorl.sample_batch.SampleBatch, evorl.envs.EnvState]

Collect one-step data in evaluation mode.

evorl.rollout.eval_rollout_episode(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey, rollout_length: int) → tuple[evorl.sample_batch.SampleBatch, evorl.envs.EnvState]

Evaluate a batch of episodic trajectories.

It skips the calls to env_step() once all environments are done. The agent’s action function, however, is still called on the remaining steps. This mechanism does not work when the function is wrapped by jax.vmap().

Parameters:
  • env_fn – Step function of a vmapped env without autoreset.

  • action_fn – The agent’s action function, e.g. agent.compute_actions.

  • env_state – State of the environment.

  • agent_state – State of the agent.

  • key – PRNG key.

  • rollout_length – The length of the episodes. This value is usually equal to the env’s max_episode_steps, or smaller.

Returns:

A tuple (trajectory, env_state).
  • trajectory – SampleBatch with shape (T, B, …), where T=rollout_length, B=#envs. When an episode is terminated
  • env_state – The last env_state after the rollout.
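The skip-when-all-done behaviour described above can be sketched with jax.lax.scan and jax.lax.cond. This is an illustration under stated assumptions, not evorl's implementation: `toy_env_step`, `toy_action_fn`, and `episodic_rollout` are hypothetical names, and a dict stands in for SampleBatch.

```python
import jax
import jax.numpy as jnp


def toy_env_step(env_state, action):
    # Hypothetical env without autoreset: once done, it stays done.
    count = env_state["count"] + 1
    done = env_state["done"] | (count >= env_state["horizon"])
    return {
        "count": count,
        "horizon": env_state["horizon"],
        "done": done,
        "reward": jnp.ones_like(action),
    }


def toy_action_fn(agent_state, env_state, key):
    return agent_state["scale"] * jax.random.normal(key, env_state["count"].shape)


def episodic_rollout(env_state, agent_state, key, rollout_length):
    def body(env_state, step_key):
        prev_done = env_state["done"]
        # The action function runs on every iteration, even after all envs finish.
        action = toy_action_fn(agent_state, env_state, step_key)
        # Step the env only while at least one env in the batch is still running.
        next_state = jax.lax.cond(
            prev_done.all(), lambda s, a: s, toy_env_step, env_state, action
        )
        sample = {
            "action": action,
            # Rewards of already-finished envs are masked out.
            "reward": jnp.where(prev_done, 0.0, next_state["reward"]),
            "done": next_state["done"],
        }
        return next_state, sample

    keys = jax.random.split(key, rollout_length)
    final_state, trajectory = jax.lax.scan(body, env_state, keys)
    return trajectory, final_state


env_state = {
    "count": jnp.zeros(2, dtype=jnp.int32),
    "horizon": jnp.array([2, 3], dtype=jnp.int32),
    "done": jnp.zeros(2, dtype=bool),
    "reward": jnp.zeros(2),
}
agent_state = {"scale": jnp.float32(1.0)}
trajectory, final_state = episodic_rollout(
    env_state, agent_state, jax.random.PRNGKey(0), rollout_length=5
)
```

In this toy run the trajectory still has the full (T, B) leading shape, but the step counter in `final_state` stops advancing once both envs are done, which shows that the env step was skipped on the trailing iterations.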

evorl.rollout.fast_eval_rollout_episode(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey, rollout_length: int) → tuple[evorl.types.PyTreeDict, evorl.envs.EnvState]

Quickly evaluate a batch of episodic trajectories.

An even faster implementation than eval_rollout_episode(). It terminates early when it is not wrapped by jax.vmap(). However, this method does not collect trajectory data; it only returns an aggregated metrics dict with the keys (episode_returns, episode_lengths), which is sufficient for use cases such as evaluation.

Parameters:
  • env_fn – Step function of a vmapped env without autoreset.

  • action_fn – The agent’s action function, e.g. agent.compute_actions.

  • env_state – State of the environment.

  • agent_state – State of the agent.

  • key – PRNG key.

  • rollout_length – The length of the episodes. This value is usually equal to the env’s max_episode_steps.

Returns:

A tuple (metrics, env_state).
  • metrics – Dict(episode_returns, episode_lengths).
  • env_state – Last env_state after evaluation.
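A metrics-only evaluation with early termination can be sketched with jax.lax.while_loop. The code below is a hedged illustration, not the library's implementation: `toy_env_step`, `toy_action_fn`, and `fast_eval` are hypothetical names, and the extra `steps_run` return value exists only to make the early exit visible.

```python
import jax
import jax.numpy as jnp


def toy_env_step(env_state, action):
    # Hypothetical env without autoreset: once done, it stays done.
    count = env_state["count"] + 1
    done = env_state["done"] | (count >= env_state["horizon"])
    return {
        "count": count,
        "horizon": env_state["horizon"],
        "done": done,
        "reward": jnp.ones_like(action),
    }


def toy_action_fn(agent_state, env_state, key):
    return agent_state["scale"] * jax.random.normal(key, env_state["count"].shape)


def fast_eval(env_state, agent_state, key, rollout_length):
    num_envs = env_state["done"].shape[0]

    def cond_fn(carry):
        t, env_state, _, _, _ = carry
        # Stop at the step limit or as soon as every env is done.
        return (t < rollout_length) & ~env_state["done"].all()

    def body_fn(carry):
        t, env_state, key, returns, lengths = carry
        key, step_key = jax.random.split(key)
        alive = ~env_state["done"]
        action = toy_action_fn(agent_state, env_state, step_key)
        env_state = toy_env_step(env_state, action)
        # Only envs that were still running contribute to the metrics.
        returns = returns + jnp.where(alive, env_state["reward"], 0.0)
        lengths = lengths + alive.astype(jnp.int32)
        return t + 1, env_state, key, returns, lengths

    init = (
        jnp.int32(0),
        env_state,
        key,
        jnp.zeros(num_envs),
        jnp.zeros(num_envs, dtype=jnp.int32),
    )
    steps_run, env_state, _, returns, lengths = jax.lax.while_loop(cond_fn, body_fn, init)
    metrics = {"episode_returns": returns, "episode_lengths": lengths}
    return metrics, env_state, steps_run


env_state = {
    "count": jnp.zeros(2, dtype=jnp.int32),
    "horizon": jnp.array([2, 3], dtype=jnp.int32),
    "done": jnp.zeros(2, dtype=bool),
    "reward": jnp.zeros(2),
}
agent_state = {"scale": jnp.float32(1.0)}
metrics, env_state, steps_run = fast_eval(
    env_state, agent_state, jax.random.PRNGKey(0), rollout_length=100
)
```

Here the loop stops after the longest episode finishes rather than running all 100 steps. Under jax.vmap, a batched lax.while_loop keeps iterating until the condition is false for every batched instance, which is consistent with the caveat about jax.vmap() above.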

evorl.rollout.rollout(env_fn: evorl.envs.EnvStepFn, action_fn: evorl.agent.AgentActionFn, env_state: evorl.envs.EnvState, agent_state: evorl.agent.AgentState, key: chex.PRNGKey, rollout_length: int, env_extra_fields: collections.abc.Sequence[str] = ()) → tuple[evorl.sample_batch.SampleBatch, evorl.envs.EnvState]

Collect trajectories with length of rollout_length.

This is a general rollout method for collecting trajectories from a vectorized env. When the env enables autoreset, the returned sequential trajectory data may contain segments from multiple episodes.

Parameters:
  • env_fn – Step function of a vmapped env.

  • action_fn – The agent’s action function, e.g. agent.compute_actions.

  • env_state – State of the environment.

  • agent_state – State of the agent.

  • key – PRNG key.

  • rollout_length – The length of the trajectory to collect.

  • env_extra_fields – Extra fields collected from env_state.info into trajectory.extras.env_extras.

Returns:

A tuple (trajectory, env_state).
  • trajectory – SampleBatch object with shape (T, B, …), where T=rollout_length, B=#envs in env_fn.
  • env_state – The last env_state after the rollout.
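A fixed-length rollout over an autoreset env can be sketched with jax.lax.scan. The names `toy_autoreset_step` and `collect` are hypothetical and the env is a toy counter, so this only illustrates the (T, B, …) layout and how several episodes can end up in one trajectory; it is not the library's code.

```python
import jax
import jax.numpy as jnp


def toy_autoreset_step(env_state, action):
    # Hypothetical env with autoreset: the step counter restarts after `done`.
    count = jnp.where(env_state["done"], 0, env_state["count"]) + 1
    return {"count": count, "done": count >= 3}


def collect(env_state, agent_state, key, rollout_length):
    def body(env_state, step_key):
        action = agent_state["scale"] * jax.random.normal(
            step_key, env_state["count"].shape
        )
        next_state = toy_autoreset_step(env_state, action)
        sample = {
            "action": action,
            "step_in_episode": next_state["count"],
            "done": next_state["done"],
        }
        return next_state, sample

    # One PRNG key per step; scan stacks the per-step samples along axis 0.
    keys = jax.random.split(key, rollout_length)
    final_state, trajectory = jax.lax.scan(body, env_state, keys)
    return trajectory, final_state


env_state = {"count": jnp.zeros(2, dtype=jnp.int32), "done": jnp.zeros(2, dtype=bool)}
agent_state = {"scale": jnp.float32(1.0)}
trajectory, final_state = collect(
    env_state, agent_state, jax.random.PRNGKey(0), rollout_length=7
)
```

With episodes of three steps and a rollout length of seven, each env's column of the trajectory holds two complete episodes plus the first step of a third, so consumers of such data typically need the `done` flags to find episode boundaries.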