Source code for evorl.networks.layer_norm
1from typing import Optional
2
3import jax
4import jax.numpy as jnp
5from flax import struct, linen as nn
6
7
[docs]
class StaticLayerNorm(nn.LayerNorm):
    """Layer normalization with a fixed (non-learnable) scale and bias.

    The learnable scale and bias of ``flax.linen.LayerNorm`` are turned off by
    default (``use_scale=False``, ``use_bias=False``). Instead, the normalized
    output is multiplied by ``fixed_scale`` and shifted by ``fixed_bias``.
    These are plain module attributes rather than ``self.param`` variables, so
    they are not part of the trainable parameters.

    Attributes:
        use_bias: Whether the parent LayerNorm also adds a learnable bias.
            Defaults to ``False``.
        use_scale: Whether the parent LayerNorm also applies a learnable scale.
            Defaults to ``False``.
        fixed_bias: Constant added to the normalized output. Defaults to a
            scalar 0.
        fixed_scale: Constant multiplied into the normalized output. Defaults
            to a scalar 1.
    """

    use_bias: bool = False
    use_scale: bool = False
    fixed_bias: jax.Array = struct.field(default_factory=lambda: jnp.zeros(()))
    fixed_scale: jax.Array = struct.field(default_factory=lambda: jnp.ones(()))

    @nn.compact
    def __call__(self, x, *, mask: Optional[jax.Array] = None):
        """Normalize ``x`` and apply the fixed affine transform.

        Args:
            x: Input array, forwarded to ``nn.LayerNorm.__call__``.
            mask: Optional mask, forwarded unchanged to the parent call.

        Returns:
            ``LayerNorm(x) * fixed_scale + fixed_bias``, in the dtype of the
            parent LayerNorm's output.
        """
        y = super().__call__(x, mask=mask)
        # The default constants are strongly-typed float32 arrays. Multiplying
        # a lower-precision result (e.g. bfloat16) by them would silently
        # promote the output to float32 and override the module's ``dtype``.
        # Casting to y.dtype keeps the parent's output dtype.
        scale = jnp.asarray(self.fixed_scale, dtype=y.dtype)
        bias = jnp.asarray(self.fixed_bias, dtype=y.dtype)
        return y * scale + bias
20
21
[docs]
22def get_norm_layer(norm_layer_type: str) -> type[nn.Module]:
23 """Get the normalization layer class based on the type."""
24 match norm_layer_type:
25 case "layer_norm":
26 norm_layer = nn.LayerNorm
27 case "static_layer_norm":
28 norm_layer = StaticLayerNorm
29 case "none":
30 norm_layer = None
31 case _:
32 raise ValueError(f"Invalid norm_layer_type: {norm_layer_type}")
33
34 return norm_layer