Source code for evorl.networks.linear

  1from collections.abc import Callable, Sequence
  2from typing import Any
  3
  4import jax
  5import jax.numpy as jnp
  6from flax import linen as nn
  7
  8from .spectral_norm import SNDense
  9from .layer_norm import get_norm_layer
 10
 11ActivationFn = Callable[[jax.Array], jax.Array]
 12Initializer = Callable[..., Any]
 13
 14
class MLP(nn.Module):
    """Feed-forward network made of stacked ``nn.Dense`` layers.

    One dense layer named ``hidden_{i}`` is created per entry of
    ``layer_sizes``. Every layer except the last is followed by the optional
    normalization layer and then ``activation``. The last layer is followed
    only by ``activation_final`` (when set) and is never normalized.
    """

    # Output width of each dense layer, in order.
    layer_sizes: Sequence[int]
    # Applied after every dense layer except the last one.
    activation: ActivationFn = nn.relu
    # Kernel initializer shared by all dense layers.
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    # Applied to the output of the last dense layer; ``None`` leaves it as is.
    activation_final: ActivationFn | None = None
    # Whether the dense layers carry a bias term.
    use_bias: bool = True
    # Invoked with no arguments to build the normalization module placed
    # between a non-final dense layer and its activation; ``None`` disables it.
    norm_layer: nn.Module | None = None

    @nn.compact
    def __call__(self, data: jax.Array):
        """Pass ``data`` through every layer and return the last layer's output."""
        last_idx = len(self.layer_sizes) - 1
        x = data

        for idx, width in enumerate(self.layer_sizes):
            dense = nn.Dense(
                width,
                name=f"hidden_{idx}",
                kernel_init=self.kernel_init,
                use_bias=self.use_bias,
            )
            x = dense(x)

            if idx == last_idx:
                # The output layer skips normalization on purpose; only the
                # optional final activation is applied.
                if self.activation_final is not None:
                    x = self.activation_final(x)
                continue

            if self.norm_layer is not None:
                x = self.norm_layer()(x)
            x = self.activation(x)

        return x
48 49
class SNMLP(nn.Module):
    """Feed-forward network whose layers are ``SNDense`` (spectral norm) layers.

    One ``SNDense`` layer named ``hidden_{i}`` is created per entry of
    ``layer_sizes``. Every layer except the last is followed by
    ``activation``; the last one is followed by ``activation_final`` when it
    is set.
    """

    # Output width of each layer, in order.
    layer_sizes: Sequence[int]
    # Applied after every layer except the last one.
    activation: ActivationFn = nn.relu
    # Kernel initializer shared by all layers.
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    # Applied to the output of the last layer; ``None`` leaves it as is.
    activation_final: ActivationFn | None = None
    # Whether the layers carry a bias term.
    use_bias: bool = True

    @nn.compact
    def __call__(self, data: jax.Array):
        """Pass ``data`` through every layer and return the last layer's output."""
        last_idx = len(self.layer_sizes) - 1
        x = data

        for idx, width in enumerate(self.layer_sizes):
            sn_dense = SNDense(
                width,
                name=f"hidden_{idx}",
                kernel_init=self.kernel_init,
                use_bias=self.use_bias,
            )
            x = sn_dense(x)

            if idx == last_idx:
                if self.activation_final is not None:
                    x = self.activation_final(x)
                continue

            x = self.activation(x)

        return x
75 76
def make_mlp(
    layer_sizes: Sequence[int],
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    activation_final: ActivationFn | None = None,
    use_bias: bool = True,
    norm_layer_type: str = "none",
) -> nn.Module:
    """Build a single MLP module.

    Args:
        layer_sizes: Output width of each layer, in order.
        activation: Activation used after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        activation_final: Optional activation for the last layer's output.
        use_bias: Whether the layers carry a bias term.
        norm_layer_type: ``"spectral_norm"`` selects :class:`SNMLP`; any other
            value is forwarded to ``get_norm_layer`` and the result is used as
            the ``norm_layer`` of an :class:`MLP`.

    Returns:
        The constructed ``SNMLP`` or ``MLP`` module.
    """
    shared_kwargs = {
        "layer_sizes": layer_sizes,
        "activation": activation,
        "kernel_init": kernel_init,
        "activation_final": activation_final,
        "use_bias": use_bias,
    }

    if norm_layer_type != "spectral_norm":
        return MLP(norm_layer=get_norm_layer(norm_layer_type), **shared_kwargs)

    # Spectral normalization lives inside the layers themselves, so no
    # separate norm layer is looked up for it.
    return SNMLP(**shared_kwargs)
105 106
def make_vmap_mlp(
    layer_sizes: Sequence[int],
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    activation_final: ActivationFn | None = None,
    use_bias: bool = True,
    norm_layer_type: str = "none",
    out_axes: int = -2,
):
    """Build several MLPs that are evaluated in parallel through ``nn.vmap``.

    The ``"params"`` collection is mapped along axis 0 and the ``"params"``
    RNG stream is split across the mapped copies. No input axis is passed to
    ``nn.vmap``, so its default applies.

    Args:
        layer_sizes: Output width of each layer, in order.
        activation: Activation used after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        activation_final: Optional activation for the last layer's output.
        use_bias: Whether the layers carry a bias term.
        norm_layer_type: ``"spectral_norm"`` maps over :class:`SNMLP`; any
            other value is forwarded to ``get_norm_layer`` and the result is
            used as the ``norm_layer`` of the mapped :class:`MLP`.
        out_axes: Passed to ``nn.vmap`` as the output axis of the mapped
            dimension.

    Returns:
        An instance of the vmapped module class.
    """
    spectral = norm_layer_type == "spectral_norm"

    # NOTE(review): only the "params" collection is mapped here. If SNDense
    # keeps variables in another collection, confirm they are handled.
    vmapped_cls = nn.vmap(
        SNMLP if spectral else MLP,
        out_axes=out_axes,
        variable_axes={"params": 0},
        split_rngs={"params": True},
    )

    module_kwargs = {
        "layer_sizes": layer_sizes,
        "activation": activation,
        "kernel_init": kernel_init,
        "activation_final": activation_final,
        "use_bias": use_bias,
    }
    if not spectral:
        module_kwargs["norm_layer"] = get_norm_layer(norm_layer_type)

    return vmapped_cls(**module_kwargs)
146 147
def make_policy_network(
    action_size: int,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    use_bias: bool = True,
    activation: ActivationFn = nn.relu,
    activation_final: ActivationFn | None = None,
    norm_layer_type: str = "none",
    obs_key: str = "",
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
) -> nn.Module:
    """Creates a policy network: (obs) -> actions.

    Args:
        action_size: Width of the output layer.
        hidden_layer_sizes: Widths of the hidden layers preceding the output
            layer.
        use_bias: Whether the layers carry a bias term.
        activation: Activation used after every layer except the last.
        activation_final: Optional activation for the output layer.
        norm_layer_type: Forwarded to ``make_mlp`` to pick the normalization
            variant.
        obs_key: When non-empty, the observation is indexed with this key
            before being fed to the network.
        kernel_init: Kernel initializer for every layer. Defaults to
            ``lecun_uniform``, the value that used to be hard-coded, so
            existing callers are unaffected. Exposed for parity with the
            value and Q network factories in this module.

    Returns:
        A module whose call takes the observation and returns the output of
        the underlying MLP.
    """

    class PolicyModule(nn.Module):
        @nn.compact
        def __call__(self, obs: jax.Array):
            if obs_key:
                obs = obs[obs_key]

            actions = make_mlp(
                layer_sizes=tuple(hidden_layer_sizes) + (action_size,),
                activation=activation,
                kernel_init=kernel_init,
                activation_final=activation_final,
                use_bias=use_bias,
                norm_layer_type=norm_layer_type,
            )(obs)

            return actions

    policy_model = PolicyModule()

    return policy_model
179 180
def make_v_network(
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    norm_layer_type: str = "none",
    obs_key: str = "",
) -> nn.Module:
    """Build a state-value network: (obs) -> value.

    Args:
        hidden_layer_sizes: Widths of the hidden layers; a final layer of
            width 1 is appended.
        activation: Activation used after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        norm_layer_type: Forwarded to ``make_mlp`` to pick the normalization
            variant.
        obs_key: When non-empty, the observation is indexed with this key
            before being fed to the network.

    Returns:
        A module whose call returns the MLP output with its trailing
        size-1 axis removed.
    """

    class VModule(nn.Module):
        @nn.compact
        def __call__(self, obs: jax.Array):
            net_input = obs[obs_key] if obs_key else obs

            value_net = make_mlp(
                layer_sizes=(*hidden_layer_sizes, 1),
                activation=activation,
                kernel_init=kernel_init,
                norm_layer_type=norm_layer_type,
            )
            values = value_net(net_input)

            # Drop the unit axis produced by the width-1 output layer.
            return values.squeeze(-1)

    return VModule()
208 209
def make_q_network(
    n_stack: int = 1,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    norm_layer_type: str = "none",
    obs_key: str = "",
) -> nn.Module:
    """Build a Q network for continuous actions: (obs, action) -> value.

    Args:
        n_stack: Number of Q functions. ``1`` builds a single MLP through
            ``make_mlp``; a larger value builds parallel MLPs through
            ``make_vmap_mlp``.
        hidden_layer_sizes: Widths of the hidden layers; a final layer of
            width 1 is appended.
        activation: Activation used after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        norm_layer_type: Forwarded to the MLP factory to pick the
            normalization variant.
        obs_key: When non-empty, the observation is indexed with this key
            before being concatenated with the actions.

    Returns:
        A module whose call returns the network output with its trailing
        size-1 axis removed.

    Raises:
        ValueError: When the module is called and ``n_stack`` is neither 1
            nor greater than 1.
    """

    class QModule(nn.Module):
        """Q Module for continuous action space."""

        # Number of Q functions evaluated for each input.
        n: int

        @nn.compact
        def __call__(self, obs: jax.Array, actions: jax.Array):
            if obs_key:
                obs = obs[obs_key]

            joint = jnp.concatenate([obs, actions], axis=-1)

            if self.n > 1:
                # Give every parallel MLP its own copy of the input by adding
                # a leading axis of length ``n``.
                net_input = jnp.broadcast_to(joint, (self.n, *joint.shape))
                build_net = make_vmap_mlp
            elif self.n == 1:
                net_input = joint
                build_net = make_mlp
            else:
                raise ValueError("n should be greater than 0")

            q_net = build_net(
                layer_sizes=(*hidden_layer_sizes, 1),
                activation=activation,
                kernel_init=kernel_init,
                norm_layer_type=norm_layer_type,
            )

            # Drop the unit axis produced by the width-1 output layer.
            return q_net(net_input).squeeze(-1)

    return QModule(n=n_stack)
254 255
def make_discrete_q_network(
    action_size: int,
    n_stack: int = 1,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = nn.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    norm_layer_type: str = "none",
    obs_key: str = "",
) -> nn.Module:
    """Build a Q network for discrete actions: (obs) -> q_values.

    Args:
        action_size: Width of the output layer.
        n_stack: Number of Q functions. ``1`` builds a single MLP through
            ``make_mlp``; a larger value builds parallel MLPs through
            ``make_vmap_mlp``.
        hidden_layer_sizes: Widths of the hidden layers preceding the output
            layer.
        activation: Activation used after every layer except the last.
        kernel_init: Kernel initializer for every layer.
        norm_layer_type: Forwarded to the MLP factory to pick the
            normalization variant.
        obs_key: When non-empty, the observation is indexed with this key
            before being fed to the network.

    Returns:
        A module whose call returns the raw network output.

    Raises:
        ValueError: When the module is called and ``n_stack`` is neither 1
            nor greater than 1.
    """

    class QModule(nn.Module):
        """Q Module for discrete action space."""

        # Number of Q functions evaluated for each input.
        n: int

        @nn.compact
        def __call__(self, obs: jax.Array):
            if obs_key:
                obs = obs[obs_key]

            if self.n > 1:
                # Give every parallel MLP its own copy of the input by adding
                # a leading axis of length ``n``.
                net_input = jnp.broadcast_to(obs, (self.n, *obs.shape))
                build_net = make_vmap_mlp
            elif self.n == 1:
                net_input = obs
                build_net = make_mlp
            else:
                raise ValueError("n should be greater than 0")

            q_net = build_net(
                layer_sizes=(*hidden_layer_sizes, action_size),
                activation=activation,
                kernel_init=kernel_init,
                norm_layer_type=norm_layer_type,
            )

            return q_net(net_input)

    return QModule(n=n_stack)