1import logging
2from typing import Any
3
4import chex
5import flax.linen as nn
6import jax.numpy as jnp
7import jax.tree_util as jtu
8
9from evorl.distribution import get_categorical_dist, get_tanh_norm_dist
10from evorl.networks import make_policy_network
11from evorl.sample_batch import SampleBatch
12from evorl.types import (
13 Action,
14 Params,
15 PolicyExtraInfo,
16 PyTreeDict,
17 pytree_field,
18 PyTreeData,
19)
20from evorl.utils import running_statistics
21from evorl.utils.jax_utils import tree_get
22from evorl.envs import Space, Box, Discrete
23
24from evorl.agent import Agent, AgentState
25
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
27
28
class ECNetworkParams(PyTreeData):
    """Network parameters held by an EC agent.

    Attributes:
        policy_params: Parameters of the policy network, i.e. the value
            produced by ``policy_network.init`` in the agents' ``init``.
    """

    policy_params: Params
33
34
class StochasticECAgent(Agent):
    """Stochastic Agent.

    Support continuous action space in [-1, 1] via TanhNormal distribution or
    discrete action space via Softmax distribution.

    Attributes:
        continuous_action: If True, the network output is split into two equal
            halves along the last axis and passed to ``get_tanh_norm_dist``;
            otherwise the raw output is passed to ``get_categorical_dist``.
        policy_network: Flax module mapping observations to raw action outputs.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``
            applied before the policy network; ``None`` disables it.
    """

    continuous_action: bool
    policy_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    @property
    def normalize_obs(self):
        """Whether an observation preprocessor is configured."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Initialize the policy parameters and the optional normalizer state.

        Args:
            obs_space: Observation space; one sample is drawn and given a
                leading batch axis of size 1 to initialize the network.
            action_space: Not used by this method.
            key: PRNG key. NOTE: the same key is used both to sample the dummy
                observation and to initialize the network.

        Returns:
            AgentState holding ``ECNetworkParams`` and the preprocessor state
            (``None`` when observation normalization is disabled).
        """
        dummy_obs = jtu.tree_map(lambda x: x[None, ...], obs_space.sample(key))
        policy_params = self.policy_network.init(key, dummy_obs)

        params_state = ECNetworkParams(
            policy_params=policy_params,
        )

        if self.normalize_obs:
            # Note: statistics are broadcasted to [T*B]
            obs_preprocessor_state = running_statistics.init_state(
                tree_get(dummy_obs, 0)
            )
        else:
            obs_preprocessor_state = None

        return AgentState(
            params=params_state, obs_preprocessor_state=obs_preprocessor_state
        )

    def _action_dist(self, agent_state: AgentState, sample_batch: SampleBatch):
        """Build the action distribution for ``sample_batch.obs``.

        Shared by ``compute_actions`` (sampling) and ``evaluate_actions`` (mode)
        so both paths always use identical preprocessing and distributions.
        """
        obs = sample_batch.obs
        if self.normalize_obs:
            obs = self.obs_preprocessor(obs, agent_state.obs_preprocessor_state)

        raw_actions = self.policy_network.apply(agent_state.params.policy_params, obs)

        if self.continuous_action:
            # The factory allocates 2 * action_dim outputs; the two halves are
            # presumably (loc, scale) parameters -- see get_tanh_norm_dist.
            return get_tanh_norm_dist(*jnp.split(raw_actions, 2, axis=-1))
        return get_categorical_dist(raw_actions)

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Sample actions from the policy distribution using ``key``.

        Returns:
            The sampled actions and an empty ``PyTreeDict`` of policy extras.
        """
        actions_dist = self._action_dist(agent_state, sample_batch)
        actions = actions_dist.sample(seed=key)
        return actions, PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Return the mode of the policy distribution (``key`` is unused).

        Returns:
            The modal actions and an empty ``PyTreeDict`` of policy extras.
        """
        actions_dist = self._action_dist(agent_state, sample_batch)
        actions = actions_dist.mode()
        return actions, PyTreeDict()
111
112
class DeterministicECAgent(Agent):
    """Deterministic agent for continuous action spaces in [-1, 1].

    The raw output of the policy network is used directly as the action.

    Attributes:
        policy_network: Flax module mapping observations to actions.
        obs_preprocessor: Optional callable ``(obs, preprocessor_state) -> obs``
            applied before the policy network; ``None`` disables it.
    """

    policy_network: nn.Module
    obs_preprocessor: Any = pytree_field(default=None, static=True)

    @property
    def normalize_obs(self):
        """True when an observation preprocessor has been supplied."""
        return self.obs_preprocessor is not None

    def init(
        self, obs_space: Space, action_space: Space, key: chex.PRNGKey
    ) -> AgentState:
        """Create the initial agent state.

        A single observation is sampled from ``obs_space`` and given a leading
        batch axis of size 1 so the policy network can be initialized.
        ``action_space`` is not consulted. The same ``key`` drives both the
        observation sample and the network initialization.
        """
        sampled_obs = obs_space.sample(key)
        batched_obs = jtu.tree_map(lambda leaf: leaf[None, ...], sampled_obs)

        network_params = ECNetworkParams(
            policy_params=self.policy_network.init(key, batched_obs)
        )

        # Note: statistics are broadcasted to [T*B]
        preprocessor_state = (
            running_statistics.init_state(tree_get(batched_obs, 0))
            if self.normalize_obs
            else None
        )

        return AgentState(
            params=network_params, obs_preprocessor_state=preprocessor_state
        )

    def compute_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Map (optionally preprocessed) observations to actions; ``key`` is unused."""
        observations = sample_batch.obs
        if self.normalize_obs:
            observations = self.obs_preprocessor(
                observations, agent_state.obs_preprocessor_state
            )

        policy_params = agent_state.params.policy_params
        return self.policy_network.apply(policy_params, observations), PyTreeDict()

    def evaluate_actions(
        self, agent_state: AgentState, sample_batch: SampleBatch, key: chex.PRNGKey
    ) -> tuple[Action, PolicyExtraInfo]:
        """Same as ``compute_actions``: a deterministic policy has no separate eval mode."""
        return self.compute_actions(agent_state, sample_batch, key)
160
161
def make_stochastic_ec_agent(
    action_space: Space,
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    use_bias: bool = True,
    norm_layer_type: str = "none",
    normalize_obs: bool = False,
    policy_obs_key: str = "",
) -> StochasticECAgent:
    """Build a ``StochasticECAgent`` for a ``Box`` or ``Discrete`` action space.

    Args:
        action_space: ``Box`` gives a continuous (TanhNormal) policy with
            ``2 * action_space.shape[0]`` network outputs. ``Discrete`` gives a
            categorical policy with ``action_space.n`` outputs.
        actor_hidden_layer_sizes: Widths of the policy network's hidden layers.
        use_bias: Forwarded to ``make_policy_network``.
        norm_layer_type: Forwarded to ``make_policy_network``.
        normalize_obs: If True, observations are preprocessed with
            ``running_statistics.normalize``.
        policy_obs_key: Forwarded to ``make_policy_network`` as ``obs_key``.

    Returns:
        The configured ``StochasticECAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Box`` nor ``Discrete``.
    """
    if isinstance(action_space, Box):
        # Two outputs per action dimension, split in half by the agent.
        action_size = action_space.shape[0] * 2
        continuous_action = True
    elif isinstance(action_space, Discrete):
        action_size = action_space.n
        continuous_action = False
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")

    policy_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        norm_layer_type=norm_layer_type,
        use_bias=use_bias,
        obs_key=policy_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return StochasticECAgent(
        continuous_action=continuous_action,
        policy_network=policy_network,
        obs_preprocessor=obs_preprocessor,
    )
197
198
def make_deterministic_ec_agent(
    action_space: Space,
    actor_hidden_layer_sizes: tuple[int, ...] = (256, 256),
    use_bias: bool = True,
    norm_layer_type: str = "none",
    normalize_obs: bool = False,
    policy_obs_key: str = "",
) -> DeterministicECAgent:
    """Build a ``DeterministicECAgent`` for a continuous (``Box``) action space.

    The policy network has ``action_space.shape[0]`` outputs followed by a final
    ``tanh``, so actions lie in [-1, 1].

    Args:
        action_space: Must be a ``Box``.
        actor_hidden_layer_sizes: Widths of the policy network's hidden layers.
        use_bias: Forwarded to ``make_policy_network``.
        norm_layer_type: Forwarded to ``make_policy_network``.
        normalize_obs: If True, observations are preprocessed with
            ``running_statistics.normalize``.
        policy_obs_key: Forwarded to ``make_policy_network`` as ``obs_key``.

    Returns:
        The configured ``DeterministicECAgent``.

    Raises:
        NotImplementedError: If ``action_space`` is not a ``Box``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(action_space, Box):
        raise NotImplementedError(
            f"Unsupported action space: {action_space}. "
            "Only continuous (Box) action space is supported."
        )

    action_size = action_space.shape[0]

    policy_network = make_policy_network(
        action_size=action_size,
        hidden_layer_sizes=actor_hidden_layer_sizes,
        use_bias=use_bias,
        activation_final=nn.tanh,
        norm_layer_type=norm_layer_type,
        obs_key=policy_obs_key,
    )

    obs_preprocessor = running_statistics.normalize if normalize_obs else None

    return DeterministicECAgent(
        policy_network=policy_network,
        obs_preprocessor=obs_preprocessor,
    )