Source code for evorl.agent

  1from abc import ABCMeta, abstractmethod
  2from collections.abc import Mapping
  3from typing import Any, Protocol
  4
  5import jax
  6import jax.tree_util as jtu
  7import chex
  8import numpy as np
  9
 10from evorl.envs import Space, is_leaf_space
 11from evorl.sample_batch import SampleBatch
 12from evorl.types import (
 13    Action,
 14    Axis,
 15    LossDict,
 16    Params,
 17    PolicyExtraInfo,
 18    PyTreeData,
 19    PyTreeNode,
 20    PyTreeDict,
 21)
 22
 23
class AgentState(PyTreeData):
    """State of the agent.

    Only `params` is required; every other field defaults to `None`.

    Attributes:
        params: The network parameters of the agent, keyed by name.
        obs_preprocessor_state: The state of the observation preprocessor.
        action_postprocessor_state: The state of the action postprocessor.
        extra_state: Extra state of the agent.
    """

    # Mapping from a string key to the corresponding parameters.
    params: Mapping[str, Params]
    # Optional; left as None when not provided.
    obs_preprocessor_state: Any = None
    # TODO: define the action_postprocessor_state
    action_postprocessor_state: Any = None
    # Free-form slot for any additional agent-specific state.
    extra_state: Any = None


# Type alias: either a full `AgentState` or an `Axis` value.
# NOTE(review): presumably used to annotate axis specifications for an
# AgentState (e.g. vmap in/out axes) -- confirm against callers.
AgentStateAxis = AgentState | Axis


class ObsPreprocessorFn(Protocol):
    """The type of the observation preprocessor function.

    Any callable that takes an observation array (plus optional extra
    positional and keyword arguments) and returns an array matches this
    protocol.
    """

    def __call__(self, obs: chex.Array, *args: Any, **kwds: Any) -> chex.Array:
        """Preprocess an observation.

        Args:
            obs: The observation array.
            *args: Extra positional arguments (unused by this default body).
            **kwds: Extra keyword arguments (unused by this default body).

        Returns:
            The preprocessed observation. The default body returns `obs`
            unchanged.
        """
        return obs
49 50
class LossFn(Protocol):
    """Call signature of an agent's loss function.

    One loss function is not always sufficient for an algorithm. DDPG, for
    instance, uses two of them: actor_loss and critic_loss.
    """

    def __call__(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute the losses.

        Args:
            agent_state: The state of the agent.
            sample_batch: The sample data handed to the loss function.
            key: JAX PRNGKey.

        Returns:
            A `LossDict` holding the computed losses.
        """
        ...
61 62
class AgentActionFn(Protocol):
    """Call signature of an agent's action function.

    The signature is the same as that of `Agent.compute_actions` and
    `Agent.evaluate_actions`.
    """

    def __call__(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Produce actions for the given sample batch.

        Args:
            agent_state: The state of the agent.
            sample_batch: The sample data handed to the action function.
            key: JAX PRNGKey.

        Returns:
            A tuple `(action, policy_extra_info)`.
        """
        ...
70 71
class Agent(PyTreeNode, metaclass=ABCMeta):
    """Agent Interface.

    An Agent is responsible for:

    - Storing models such as the actor and the critic.
    - Interacting with the environment through `compute_actions()` or
      `evaluate_actions()`.
    - Optionally computing algorithm-specific losses.

    Subclasses must implement every abstract method declared here.
    """

    @abstractmethod
    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Build the initial state of the agent.

        Args:
            obs_space: The observation space.
            action_space: The action space.
            key: JAX PRNGKey.

        Returns:
            The initial `AgentState`.
        """

    @abstractmethod
    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Get actions from the policy model and add exploration noise.

        This method is exclusively used for rollout.

        Args:
            agent_state: The state of the agent.
            sample_batch: Previous transition data. Usually only contains `obs`.
            key: JAX PRNGKey.

        Returns:
            A tuple `(action, policy_extra_info)`. `policy_extra_info` is a
            dict with extra information about the policy, such as the current
            hidden state of an RNN.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Get the best action from the action distribution.

        This method is exclusively used for evaluation.

        Args:
            agent_state: The state of the agent.
            sample_batch: Previous transition data. Usually only contains `obs`.
            key: JAX PRNGKey.

        Returns:
            A tuple `(action, policy_extra_info)`. `policy_extra_info` is a
            dict with extra information about the policy, such as the current
            hidden state of an RNN.
        """
        raise NotImplementedError
124 125
class RandomAgent(Agent):
    """An agent that takes uniform random actions."""

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the agent state.

        The agent has no network parameters; the observation and action
        spaces are stored in `extra_state` so that `compute_actions` can read
        them later.

        Args:
            obs_space: The observation space.
            action_space: The action space.
            key: JAX PRNGKey (not used by this agent).

        Returns:
            An `AgentState` with empty `params` and the two spaces stored in
            `extra_state`.
        """
        extra_state = PyTreeDict(
            action_space=action_space,
            obs_space=obs_space,
        )
        return AgentState(params={}, extra_state=extra_state)

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample one random action per batch element.

        The batch shape is taken to be the leading dimensions of the first
        observation leaf that are not part of the first leaf observation
        space's own shape.

        Args:
            agent_state: The state of the agent; `extra_state` must hold the
                `obs_space` and `action_space` stored by `init`.
            sample_batch: Transition data; only `obs` is read.
            key: JAX PRNGKey.

        Returns:
            A tuple `(actions, policy_extra_info)` where `actions` has the
            batch shape as leading dimensions and `policy_extra_info` is an
            empty `PyTreeDict`.
        """
        obs_space = agent_state.extra_state.obs_space
        action_space = agent_state.extra_state.action_space

        first_obs = jtu.tree_leaves(sample_batch.obs)[0]
        first_obs_space = jtu.tree_leaves(obs_space, is_leaf=is_leaf_space)[0]

        # Use an explicit non-negative stop index. Slicing with
        # `[:-len(space_shape)]` turns into `[:0]` for a scalar space
        # (shape `()`), which would silently discard every batch dimension.
        num_batch_dims = max(len(first_obs.shape) - len(first_obs_space.shape), 0)
        batch_shapes = tuple(first_obs.shape[:num_batch_dims])

        chex.assert_tree_shape_prefix(sample_batch.obs, batch_shapes)

        # Wrap the sampler in one vmap per batch dimension.
        action_sample_fn = action_space.sample
        for _ in batch_shapes:
            action_sample_fn = jax.vmap(action_sample_fn)

        # `np.prod(())` is the float 1.0, so convert to int before using it
        # as the number of keys.
        num_keys = int(np.prod(batch_shapes))
        flat_keys = jax.random.split(key, num_keys)
        # Keep whatever trailing key dimensions `split` produced instead of
        # hard-coding 2, so typed key arrays work as well as raw uint32 keys.
        action_keys = flat_keys.reshape((*batch_shapes, *flat_keys.shape[1:]))

        actions = action_sample_fn(action_keys)
        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample random actions for evaluation.

        Delegates to `compute_actions`, so evaluation is random as well.

        Args:
            agent_state: The state of the agent.
            sample_batch: Transition data; only `obs` is read.
            key: JAX PRNGKey.

        Returns:
            The same `(actions, policy_extra_info)` tuple that
            `compute_actions` returns.
        """
        return self.compute_actions(agent_state, sample_batch, key)