1from collections.abc import Callable, Sequence
2from typing import Any
3
4import jax
5import jax.numpy as jnp
6from flax import linen as nn
7
8from .spectral_norm import SNDense
9from .layer_norm import get_norm_layer
10
# Signature of the activation callables accepted below: one array in, one array out.
ActivationFn = Callable[[jax.Array], jax.Array]
# Signature of the ``kernel_init`` arguments accepted below; deliberately
# loose (any arguments, any return value).
Initializer = Callable[..., Any]
13
14
[docs]
class MLP(nn.Module):
    """Multi-layer perceptron built from a stack of ``nn.Dense`` layers.

    Each entry of ``layer_sizes`` creates one dense layer named
    ``hidden_{i}``. Every layer except the last is followed by the optional
    normalization layer and then by ``activation``. The last layer is never
    normalized and is followed by ``activation_final`` only when one is
    given. With an empty ``layer_sizes`` the input is returned unchanged.

    Attributes:
        layer_sizes: Output width of each dense layer, in order.
        activation: Function applied after every layer except the last.
        kernel_init: Initializer passed to the kernel of every dense layer.
        activation_final: Function applied to the output of the last layer,
            or ``None`` to return that output as is.
        use_bias: Forwarded to every dense layer as ``use_bias``.
        norm_layer: Zero-argument factory, called once per non-final layer,
            whose result is applied to the layer output before
            ``activation``; ``None`` disables normalization.
    """

    layer_sizes: Sequence[int]
    activation: ActivationFn = nn.relu
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    activation_final: ActivationFn | None = None
    use_bias: bool = True
    # Invoked as ``self.norm_layer()`` below, so this must be a factory
    # (for example a module class), not an already-built module instance.
    norm_layer: Callable[[], Callable[[jax.Array], jax.Array]] | None = None

    @nn.compact
    def __call__(self, data: jax.Array) -> jax.Array:
        """Run ``data`` through the dense stack and return the result."""
        hidden = data
        last_index = len(self.layer_sizes) - 1
        for i, hidden_size in enumerate(self.layer_sizes):
            hidden = nn.Dense(
                hidden_size,
                name=f"hidden_{i}",
                kernel_init=self.kernel_init,
                use_bias=self.use_bias,
            )(hidden)

            if i != last_index:
                # Hidden layers: optional normalization, then activation.
                if self.norm_layer is not None:
                    hidden = self.norm_layer()(hidden)

                hidden = self.activation(hidden)
            elif self.activation_final is not None:
                # Output layer: no normalization, only the optional final
                # activation.
                hidden = self.activation_final(hidden)

        return hidden
48
49
[docs]
class SNMLP(nn.Module):
    """MLP whose layers are ``SNDense`` (spectral-normalization) layers.

    One ``SNDense`` layer named ``hidden_{i}`` is created per entry of
    ``layer_sizes``. ``activation`` follows every layer except the last;
    the last layer is followed by ``activation_final`` when it is not
    ``None``.

    Attributes:
        layer_sizes: Output width of each layer, in order.
        activation: Function applied after every layer except the last.
        kernel_init: Initializer passed to every layer as ``kernel_init``.
        activation_final: Optional function applied to the last layer's
            output.
        use_bias: Forwarded to every layer as ``use_bias``.
    """

    layer_sizes: Sequence[int]
    activation: ActivationFn = nn.relu
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    activation_final: ActivationFn | None = None
    use_bias: bool = True

    @nn.compact
    def __call__(self, data: jax.Array):
        final_idx = len(self.layer_sizes) - 1
        out = data
        for idx, width in enumerate(self.layer_sizes):
            layer = SNDense(
                width,
                name=f"hidden_{idx}",
                kernel_init=self.kernel_init,
                use_bias=self.use_bias,
            )
            out = layer(out)

            if idx == final_idx:
                # Output layer: only the optional final activation applies.
                if self.activation_final is not None:
                    out = self.activation_final(out)
            else:
                out = self.activation(out)
        return out
75
76
[docs]
def make_mlp(
    layer_sizes: Sequence[int],
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    activation_final: ActivationFn | None = None,
    use_bias: bool = True,
    norm_layer_type: str = "none",
) -> nn.Module:
    """Build an ``SNMLP`` or an ``MLP`` depending on ``norm_layer_type``.

    Args:
        layer_sizes: Output width of each layer, in order.
        activation: Function applied after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        activation_final: Optional function applied after the last layer.
        use_bias: Whether the layers use a bias term.
        norm_layer_type: ``"spectral_norm"`` selects ``SNMLP``. Any other
            value is handed to ``get_norm_layer`` and the result becomes
            the ``norm_layer`` of an ``MLP``.

    Returns:
        The constructed module.
    """
    shared = dict(
        layer_sizes=layer_sizes,
        activation=activation,
        kernel_init=kernel_init,
        activation_final=activation_final,
        use_bias=use_bias,
    )

    if norm_layer_type == "spectral_norm":
        return SNMLP(**shared)

    # ``get_norm_layer`` is consulted only on the non-spectral path.
    return MLP(norm_layer=get_norm_layer(norm_layer_type), **shared)
105
106
[docs]
def make_vmap_mlp(
    layer_sizes: Sequence[int],
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    activation_final: ActivationFn | None = None,
    use_bias: bool = True,
    norm_layer_type: str = "none",
    out_axes: int = -2,
):
    """Build an ``nn.vmap``-lifted ``SNMLP`` or ``MLP``.

    The module class is wrapped with ``nn.vmap`` using
    ``variable_axes={"params": 0}`` and ``split_rngs={"params": True}``;
    ``in_axes`` is left at ``nn.vmap``'s default.

    Args:
        layer_sizes: Output width of each layer, in order.
        activation: Function applied after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        activation_final: Optional function applied after the last layer.
        use_bias: Whether the layers use a bias term.
        norm_layer_type: ``"spectral_norm"`` selects ``SNMLP``. Any other
            value is handed to ``get_norm_layer`` and the result becomes
            the ``norm_layer`` of an ``MLP``.
        out_axes: Forwarded to ``nn.vmap`` as ``out_axes``.

    Returns:
        An instance of the lifted module class.
    """

    def _lift(module_cls):
        # Same vmap configuration for both module classes.
        return nn.vmap(
            module_cls,
            out_axes=out_axes,
            variable_axes={"params": 0},
            split_rngs={"params": True},
        )

    if norm_layer_type == "spectral_norm":
        return _lift(SNMLP)(
            layer_sizes=layer_sizes,
            activation=activation,
            kernel_init=kernel_init,
            activation_final=activation_final,
            use_bias=use_bias,
        )

    return _lift(MLP)(
        layer_sizes=layer_sizes,
        activation=activation,
        kernel_init=kernel_init,
        activation_final=activation_final,
        norm_layer=get_norm_layer(norm_layer_type),
        use_bias=use_bias,
    )
146
147
[docs]
def make_policy_network(
    action_size: int,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    use_bias: bool = True,
    activation: ActivationFn = nn.relu,
    activation_final: ActivationFn | None = None,
    norm_layer_type: str = "none",
    obs_key: str = "",
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
) -> nn.Module:
    """Create a policy network: (obs) -> actions.

    The network is an MLP built by ``make_mlp`` with layer sizes
    ``hidden_layer_sizes + (action_size,)``.

    Args:
        action_size: Width of the output layer.
        hidden_layer_sizes: Widths of the hidden layers.
        use_bias: Whether the layers use a bias term.
        activation: Function applied after every layer except the last.
        activation_final: Optional function applied after the output layer.
        norm_layer_type: Forwarded to ``make_mlp`` as ``norm_layer_type``.
        obs_key: When non-empty, the observation is indexed with this key
            (``obs[obs_key]``) before being fed to the MLP.
        kernel_init: Kernel initializer for every layer. Added last so that
            existing positional calls keep working; the default matches the
            initializer that was previously hard-coded.

    Returns:
        The policy module.
    """

    class PolicyModule(nn.Module):
        """Maps an observation to the output of the policy MLP."""

        @nn.compact
        def __call__(self, obs: jax.Array):
            if obs_key:
                obs = obs[obs_key]

            actions = make_mlp(
                layer_sizes=tuple(hidden_layer_sizes) + (action_size,),
                activation=activation,
                kernel_init=kernel_init,
                activation_final=activation_final,
                use_bias=use_bias,
                norm_layer_type=norm_layer_type,
            )(obs)

            return actions

    return PolicyModule()
179
180
[docs]
def make_v_network(
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    norm_layer_type: str = "none",
    obs_key: str = "",
) -> nn.Module:
    """Create a state-value network: (obs) -> value.

    Args:
        hidden_layer_sizes: Widths of the hidden layers; a final layer of
            width 1 is appended.
        activation: Function applied after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        norm_layer_type: Forwarded to ``make_mlp`` as ``norm_layer_type``.
        obs_key: When non-empty, the observation is indexed with this key
            before being fed to the MLP.

    Returns:
        A module whose output has the trailing size-1 axis squeezed away.
    """

    class VModule(nn.Module):
        """Maps an observation to a scalar value estimate."""

        @nn.compact
        def __call__(self, obs: jax.Array):
            features = obs[obs_key] if obs_key else obs

            critic = make_mlp(
                layer_sizes=(*hidden_layer_sizes, 1),
                activation=activation,
                kernel_init=kernel_init,
                norm_layer_type=norm_layer_type,
            )
            values = critic(features)

            # Drop the size-1 output axis of the last layer.
            return jnp.squeeze(values, axis=-1)

    return VModule()
208
209
[docs]
def make_q_network(
    n_stack: int = 1,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    norm_layer_type: str = "none",
    obs_key: str = "",
) -> nn.Module:
    """Create an action-value network: (obs, action) -> value.

    Args:
        n_stack: Number of Q-functions. ``1`` builds a single MLP with
            ``make_mlp``; a larger value builds the MLP with
            ``make_vmap_mlp`` and feeds it the input broadcast along a new
            leading axis of that size.
        hidden_layer_sizes: Widths of the hidden layers; a final layer of
            width 1 is appended.
        activation: Function applied after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        norm_layer_type: Forwarded to the MLP builder.
        obs_key: When non-empty, the observation is indexed with this key
            before being concatenated with the actions.

    Returns:
        A module whose output has the trailing size-1 axis squeezed away.
        Calling it raises ``ValueError`` when ``n_stack`` is below 1.
    """

    class QModule(nn.Module):
        """Q Module for continuous action space."""

        n: int

        @nn.compact
        def __call__(self, obs: jax.Array, actions: jax.Array):
            state = obs[obs_key] if obs_key else obs
            features = jnp.concatenate([state, actions], axis=-1)

            if self.n > 1:
                # One copy of the input per stacked Q-function.
                tiled = jnp.broadcast_to(features, (self.n, *features.shape))
                ensemble = make_vmap_mlp(
                    layer_sizes=(*hidden_layer_sizes, 1),
                    activation=activation,
                    kernel_init=kernel_init,
                    norm_layer_type=norm_layer_type,
                )
                q_values = ensemble(tiled)
            elif self.n == 1:
                single = make_mlp(
                    layer_sizes=(*hidden_layer_sizes, 1),
                    activation=activation,
                    kernel_init=kernel_init,
                    norm_layer_type=norm_layer_type,
                )
                q_values = single(features)
            else:
                raise ValueError("n should be greater than 0")

            return jnp.squeeze(q_values, axis=-1)

    return QModule(n=n_stack)
254
255
[docs]
def make_discrete_q_network(
    action_size: int,
    n_stack: int = 1,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    norm_layer_type: str = "none",
    obs_key: str = "",
) -> nn.Module:
    """Create a Q network for a discrete action space: (obs) -> q_values.

    Args:
        action_size: Width of the output layer.
        n_stack: Number of Q-functions. ``1`` builds a single MLP with
            ``make_mlp``; a larger value builds the MLP with
            ``make_vmap_mlp`` and feeds it the observation broadcast along
            a new leading axis of that size.
        hidden_layer_sizes: Widths of the hidden layers.
        activation: Function applied after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        norm_layer_type: Forwarded to the MLP builder.
        obs_key: When non-empty, the observation is indexed with this key
            before being fed to the MLP.

    Returns:
        A module returning the MLP output unchanged. Calling it raises
        ``ValueError`` when ``n_stack`` is below 1.
    """

    class QModule(nn.Module):
        """Q Module for discrete action space."""

        n: int

        @nn.compact
        def __call__(self, obs: jax.Array):
            state = obs[obs_key] if obs_key else obs

            if self.n > 1:
                # One copy of the observation per stacked Q-function.
                tiled = jnp.broadcast_to(state, (self.n, *state.shape))
                ensemble = make_vmap_mlp(
                    layer_sizes=(*hidden_layer_sizes, action_size),
                    activation=activation,
                    kernel_init=kernel_init,
                    norm_layer_type=norm_layer_type,
                )
                return ensemble(tiled)

            if self.n == 1:
                single = make_mlp(
                    layer_sizes=(*hidden_layer_sizes, action_size),
                    activation=activation,
                    kernel_init=kernel_init,
                    norm_layer_type=norm_layer_type,
                )
                return single(state)

            raise ValueError("n should be greater than 0")

    return QModule(n=n_stack)