1from abc import ABCMeta, abstractmethod
2from collections.abc import Mapping
3from typing import Any, Protocol
4
5import jax
6import jax.tree_util as jtu
7import chex
8import numpy as np
9
10from evorl.envs import Space, is_leaf_space
11from evorl.sample_batch import SampleBatch
12from evorl.types import (
13 Action,
14 Axis,
15 LossDict,
16 Params,
17 PolicyExtraInfo,
18 PyTreeData,
19 PyTreeNode,
20 PyTreeDict,
21)
22
23
class AgentState(PyTreeData):
    """Container for everything an agent carries between calls.

    Attributes:
        params: Network parameters of the agent, keyed by string name.
        obs_preprocessor_state: State used by the observation preprocessor,
            or None when there is none.
        action_postprocessor_state: State used by the action postprocessor,
            or None when there is none.
        extra_state: Any additional agent-specific state, or None.
    """

    # Required field: the only one without a default.
    params: Mapping[str, Params]

    # Optional fields below all default to None.
    obs_preprocessor_state: Any = None
    # TODO: define the action_postprocessor_state
    action_postprocessor_state: Any = None
    extra_state: Any = None
39
40
# Axis specification for an AgentState: either an AgentState pytree or a
# single `Axis` value.
# NOTE(review): presumably used as a jax.vmap in_axes/out_axes spec
# (per-field axes vs. one axis for the whole state) — confirm at call sites.
AgentStateAxis = AgentState | Axis
42
43
class ObsPreprocessorFn(Protocol):
    """Callable signature for observation preprocessors.

    An implementation takes an observation array, plus any extra positional
    or keyword arguments, and returns an array.
    """

    def __call__(self, obs: chex.Array, *args: Any, **kwds: Any) -> chex.Array:
        """Return the processed observation; the default is the identity."""
        return obs
49
50
class LossFn(Protocol):
    """Callable signature for an agent's loss function.

    A single loss function is not always enough. DDPG, for example, uses two:
    ``actor_loss`` and ``critic_loss``.
    """

    def __call__(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> LossDict:
        """Compute losses for ``sample_batch`` under ``agent_state``.

        Args:
            agent_state: The state of the agent.
            sample_batch: Batch of transition data.
            key: JAX PRNGKey.

        Returns:
            A LossDict.
        """
61
62
class AgentActionFn(Protocol):
    """Callable signature for an agent's action function."""

    def __call__(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Produce actions for ``sample_batch`` under ``agent_state``.

        Args:
            agent_state: The state of the agent.
            sample_batch: Transition data to act on.
            key: JAX PRNGKey.

        Returns:
            A tuple ``(action, policy_extra_info)``.
        """
70
71
class Agent(PyTreeNode, metaclass=ABCMeta):
    """Abstract agent interface.

    An agent is responsible for:

    - holding its models, such as an actor and a critic;
    - interacting with the environment through ``compute_actions()``
      (rollout) or ``evaluate_actions()`` (evaluation);
    - optionally, computing algorithm-specific losses.
    """

    @abstractmethod
    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial AgentState.

        Args:
            obs_space: Observation space of the environment.
            action_space: Action space of the environment.
            key: JAX PRNGKey.

        Returns:
            The initial agent state.
        """

    @abstractmethod
    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions from the policy model, with exploration noise added.

        Used exclusively during rollout.

        Args:
            agent_state: The state of the agent.
            sample_batch: Previous transition data; usually contains only `obs`.
            key: JAX PRNGKey.

        Returns:
            A tuple ``(action, policy_extra_info)``. ``policy_extra_info`` is a
            dict with extra information about the policy, such as the current
            hidden state of an RNN.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Pick the best action from the action distribution.

        Used exclusively during evaluation.

        Args:
            agent_state: The state of the agent.
            sample_batch: Previous transition data; usually contains only `obs`.
            key: JAX PRNGKey.

        Returns:
            A tuple ``(action, policy_extra_info)``. ``policy_extra_info`` is a
            dict with extra information about the policy, such as the current
            hidden state of an RNN.
        """
        raise NotImplementedError
124
125
class RandomAgent(Agent):
    """An agent that takes uniform random actions.

    It has no parameters. ``init`` only stores the observation and action
    spaces in ``extra_state``. Actions are then drawn with
    ``action_space.sample``.
    """

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create an AgentState with empty params and the spaces in extra_state.

        Args:
            obs_space: Observation space; later used to infer the batch axes
                of incoming observations.
            action_space: Action space that actions are sampled from.
            key: JAX PRNGKey (unused: there is nothing to initialize).

        Returns:
            ``AgentState(params={}, extra_state=PyTreeDict(action_space=...,
            obs_space=...))``.
        """
        extra_state = PyTreeDict(
            action_space=action_space,
            obs_space=obs_space,
        )
        return AgentState(params={}, extra_state=extra_state)

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample one random action per observation in the batch.

        Batch axes are the leading axes of the first ``obs`` leaf that are not
        part of the first observation-space leaf's shape.
        ``action_space.sample`` is vmapped once per batch axis.

        Args:
            agent_state: State produced by ``init``.
            sample_batch: Transition data; only ``obs`` is read.
            key: JAX PRNGKey, split into one key per sampled action.

        Returns:
            A tuple ``(actions, PyTreeDict())``. The leading shape of
            ``actions`` equals the inferred batch shape.
        """
        obs_space = agent_state.extra_state.obs_space
        action_space = agent_state.extra_state.action_space

        obs_leaf = jtu.tree_leaves(sample_batch.obs)[0]
        obs_space_leaf = jtu.tree_leaves(obs_space, is_leaf=is_leaf_space)[0]

        # Slice by an explicit, non-negative batch rank. `shape[:-k]` would
        # give `shape[:0] == ()` when k == 0 (scalar observation space),
        # silently dropping every batch axis.
        batch_ndim = max(len(obs_leaf.shape) - len(obs_space_leaf.shape), 0)
        batch_shapes = obs_leaf.shape[:batch_ndim]

        chex.assert_tree_shape_prefix(sample_batch.obs, batch_shapes)

        action_sample_fn = action_space.sample
        for _ in range(batch_ndim):
            action_sample_fn = jax.vmap(action_sample_fn)

        # np.prod(()) is the float 1.0, which jax.random.split does not
        # accept as a key count; cast to int for the unbatched case.
        num_keys = int(np.prod(batch_shapes))
        # Append the key's own shape instead of a hard-coded 2, so raw uint32
        # keys (shape (2,)) and typed key arrays (shape ()) both work.
        action_keys = jax.random.split(key, num_keys).reshape(
            (*batch_shapes, *key.shape)
        )

        actions = action_sample_fn(action_keys)
        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Identical to ``compute_actions``: a random agent has no greedy policy."""
        return self.compute_actions(agent_state, sample_batch, key)