evorl.agent

Module Contents

Classes

Agent

Agent Interface.

AgentActionFn

The type of the agent’s action function.

AgentState

State of the agent.

LossFn

The type of the agent’s loss function.

ObsPreprocessorFn

The type of the observation preprocessor function.

RandomAgent

An agent that takes uniform random actions.

Data

API

class evorl.agent.Agent

Bases: evorl.types.PyTreeNode

Agent Interface.

The responsibilities of an Agent:

  • Store models such as the actor and critic.

  • Interact with the environment via compute_actions() or evaluate_actions().

  • Compute algorithm-specific losses (optional).
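The three responsibilities above can be sketched in plain Python. This is a minimal illustration, not evorl's implementation: ToyAgentState, ToyAgent, and LinearAgent are hypothetical stand-ins, an integer seed replaces the JAX PRNGKey, and a bare observation list replaces SampleBatch.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
import random


@dataclass(frozen=True)
class ToyAgentState:
    # Hypothetical stand-in for evorl.agent.AgentState; only holds params.
    params: dict


class ToyAgent(ABC):
    # Hypothetical stand-in for the Agent interface (not evorl's class).

    @abstractmethod
    def init(self, obs_size, action_size, seed): ...

    @abstractmethod
    def compute_actions(self, agent_state, obs, seed): ...

    @abstractmethod
    def evaluate_actions(self, agent_state, obs, seed): ...


class LinearAgent(ToyAgent):
    def init(self, obs_size, action_size, seed):
        # Responsibility 1: create and store the model (here, an "actor" matrix).
        rng = random.Random(seed)
        weights = [[rng.uniform(-1.0, 1.0) for _ in range(obs_size)]
                   for _ in range(action_size)]
        return ToyAgentState(params={"actor": weights})

    def _mean_action(self, agent_state, obs):
        # Linear policy: one dot product per action dimension.
        return [sum(w * o for w, o in zip(row, obs))
                for row in agent_state.params["actor"]]

    def compute_actions(self, agent_state, obs, seed):
        # Responsibility 2 (rollout): policy output plus Gaussian exploration noise.
        rng = random.Random(seed)
        mean = self._mean_action(agent_state, obs)
        return [a + rng.gauss(0.0, 0.1) for a in mean], {}

    def evaluate_actions(self, agent_state, obs, seed):
        # Responsibility 2 (evaluation): deterministic policy output, seed unused.
        return self._mean_action(agent_state, obs), {}


agent = LinearAgent()
state = agent.init(obs_size=3, action_size=2, seed=0)
action, extra_info = agent.compute_actions(state, [1.0, 0.0, 0.0], seed=1)
```

The agent object itself holds no mutable data in this sketch; everything that changes lives in the state returned by init(), which mirrors how the methods below take agent_state as an explicit argument.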

abstract compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]

Get actions from the policy model and add exploration noise.

This method is exclusively used for rollout.

Parameters:
  • agent_state – the state of the agent.

  • sample_batch – Previous transition data. Usually contains only obs.

  • key – JAX PRNGKey.

Returns:

A tuple (action, policy_extra_info), where policy_extra_info is a dict containing extra information about the policy, such as the current hidden state of an RNN.

abstract evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]

Get the best action from the action distribution.

This method is exclusively used for evaluation.

Parameters:
  • agent_state – the state of the agent.

  • sample_batch – Previous transition data. Usually contains only obs.

  • key – JAX PRNGKey.

Returns:

A tuple (action, policy_extra_info), where policy_extra_info is a dict containing extra information about the policy, such as the current hidden state of an RNN.
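The difference between the rollout and evaluation methods can be illustrated for a discrete policy. The sketch below assumes a categorical action distribution and reads "best action" as its mode; rollout_action and eval_action are hypothetical helpers, not evorl functions, and an integer seed stands in for the JAX PRNGKey.

```python
import random


def rollout_action(probs, seed):
    """Sample an action index from the categorical distribution (exploration)."""
    rng = random.Random(seed)
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    # Empty dict plays the role of policy_extra_info in this sketch.
    return action, {}


def eval_action(probs):
    """Pick the most probable action index (greedy and deterministic)."""
    action = max(range(len(probs)), key=lambda i: probs[i])
    return action, {}


probs = [0.1, 0.7, 0.2]
sampled, _ = rollout_action(probs, seed=0)  # may be any of 0, 1, 2
greedy, _ = eval_action(probs)              # always index 1 for these probs
```

Sampling during rollout keeps the collected data diverse, while the greedy choice during evaluation removes sampling variance from the reported performance.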

abstract init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) → evorl.agent.AgentState

class evorl.agent.AgentActionFn

Bases: typing.Protocol

The type of the agent’s action function.

class evorl.agent.AgentState

Bases: evorl.types.PyTreeData

State of the agent.

Variables:
  • params – The network parameters of the agent.

  • obs_preprocessor_state – The state of the observation preprocessor.

  • action_postprocessor_state – The state of the action postprocessor.

  • extra_state – Extra state of the agent.

action_postprocessor_state: Any

None

extra_state: Any

None

obs_preprocessor_state: Any

None

params: collections.abc.Mapping[str, evorl.types.Params]

None
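The layout of the four fields can be sketched with a frozen dataclass. ToyAgentState is a hypothetical stand-in, not the evorl class; the None defaults and the copy-on-update pattern are assumptions based on how immutable pytree-style containers are commonly used.

```python
from dataclasses import dataclass, replace
from typing import Any, Mapping


@dataclass(frozen=True)
class ToyAgentState:
    # Hypothetical mirror of the documented fields.
    params: Mapping[str, Any]
    obs_preprocessor_state: Any = None      # assumed default
    action_postprocessor_state: Any = None  # assumed default
    extra_state: Any = None                 # assumed default


# params maps a model name to that model's parameters.
state = ToyAgentState(params={"actor": [0.0, 0.0], "critic": [0.0]})

# Updating a field creates a new state object; the old one is left untouched.
new_state = replace(state, obs_preprocessor_state={"mean": 0.5, "count": 10})
```

Because the state is never mutated in place, the same state can be passed to several function calls without one call affecting another.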

evorl.agent.AgentStateAxis

None

class evorl.agent.LossFn

Bases: typing.Protocol

The type of the agent’s loss function.

In some cases, a single loss function is not enough. For example, DDPG has two loss functions: actor_loss and critic_loss.
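A rough sketch of two loss functions sharing one callable type follows. ToyLossFn and its (params, batch) signature are assumptions made for illustration and may differ from the real protocol; the scalar "networks" are deliberately simplistic so that the DDPG-style structure stays visible.

```python
from typing import Any, Mapping, Protocol


class ToyLossFn(Protocol):
    # Hypothetical loss-function type: takes params and a batch, returns named losses.
    def __call__(self, params: Mapping[str, Any],
                 batch: Mapping[str, Any]) -> Mapping[str, float]: ...


def critic_loss(params, batch):
    # Mean squared error between a scalar-weight Q estimate and the target Q.
    errors = [(params["critic"] * x - y) ** 2
              for x, y in zip(batch["obs"], batch["target_q"])]
    return {"critic_loss": sum(errors) / len(errors)}


def actor_loss(params, batch):
    # Negative mean Q value of the actor's actions: minimizing it maximizes Q.
    q_values = [params["critic"] * params["actor"] * x for x in batch["obs"]]
    return {"actor_loss": -sum(q_values) / len(q_values)}


# Both functions satisfy the same protocol and can be optimized separately.
loss_fns: list[ToyLossFn] = [critic_loss, actor_loss]

params = {"actor": 2.0, "critic": 0.5}
batch = {"obs": [1.0, 2.0], "target_q": [1.0, 1.0]}
losses = {}
for fn in loss_fns:
    losses.update(fn(params, batch))
```

Keeping the losses as separate callables lets each one be differentiated with respect to its own parameter group, which is what an actor-critic update requires.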

class evorl.agent.ObsPreprocessorFn

Bases: typing.Protocol

The type of the observation preprocessor function.

class evorl.agent.RandomAgent

Bases: evorl.agent.Agent

An agent that takes uniform random actions.
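Uniform random action selection can be sketched as follows, assuming a continuous action space described by per-dimension lower and upper bounds. uniform_random_action is a hypothetical helper, not part of evorl, and an integer seed stands in for the JAX PRNGKey.

```python
import random


def uniform_random_action(low, high, seed):
    """Draw one action uniformly between per-dimension bounds."""
    rng = random.Random(seed)
    # The observation is ignored entirely: the action depends only on the seed.
    return [rng.uniform(lo, hi) for lo, hi in zip(low, high)]


low = [-1.0, -2.0]
high = [1.0, 2.0]
action = uniform_random_action(low, high, seed=42)
```

An agent of this kind has no parameters to learn, which makes it useful as a baseline or for collecting initial exploration data.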

compute_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]

evaluate_actions(agent_state: evorl.agent.AgentState, sample_batch: evorl.sample_batch.SampleBatch, key: chex.PRNGKey) → tuple[evorl.types.Action, evorl.types.PolicyExtraInfo]

init(obs_space: evorl.envs.Space, action_space: evorl.envs.Space, key: chex.PRNGKey) → evorl.agent.AgentState