Source code for evorl.networks.layer_norm

 1from typing import Optional
 2
 3import jax
 4import jax.numpy as jnp
 5from flax import struct, linen as nn
 6
 7
class StaticLayerNorm(nn.LayerNorm):
    """Layer normalization with a fixed (non-learnable) scale and bias.

    The learnable affine parameters of ``nn.LayerNorm`` are disabled by
    default (``use_bias`` and ``use_scale`` are ``False``). Instead, the
    normalized output is transformed with the constant ``fixed_scale`` and
    ``fixed_bias`` fields.

    Attributes:
        use_bias: Forwarded to ``nn.LayerNorm``; defaults to ``False`` so no
            learnable bias is created.
        use_scale: Forwarded to ``nn.LayerNorm``; defaults to ``False`` so no
            learnable scale is created.
        fixed_bias: Constant added after normalization. Defaults to the
            scalar ``0``.
        fixed_scale: Constant multiplier applied after normalization.
            Defaults to the scalar ``1``.
    """

    use_bias: bool = False
    use_scale: bool = False
    # Constants are created via default_factory so each instance gets its own
    # array instead of sharing one evaluated at class-definition time.
    fixed_bias: jax.Array = struct.field(default_factory=lambda: jnp.zeros(()))
    fixed_scale: jax.Array = struct.field(default_factory=lambda: jnp.ones(()))

    @nn.compact
    def __call__(self, x, *, mask: Optional[jax.Array] = None):
        """Normalize ``x`` and apply the fixed affine transform.

        Args:
            x: Input array to normalize.
            mask: Optional mask, forwarded unchanged to ``nn.LayerNorm``.

        Returns:
            ``LayerNorm(x) * fixed_scale + fixed_bias``, with the same dtype as
            the normalized output of the parent ``nn.LayerNorm``.
        """
        y = super().__call__(x, mask=mask)
        # Cast the constants to the normalized output's dtype. The defaults
        # are strongly-typed float32 arrays, so combining them with e.g. a
        # bfloat16 ``y`` would otherwise promote the result to float32.
        # NOTE(review): fixed_scale / fixed_bias are assumed to broadcast
        # against ``y``; confirm for non-scalar values.
        scale = jnp.asarray(self.fixed_scale, dtype=y.dtype)
        bias = jnp.asarray(self.fixed_bias, dtype=y.dtype)
        return y * scale + bias
[docs] 22def get_norm_layer(norm_layer_type: str) -> type[nn.Module]: 23 """Get the normalization layer class based on the type.""" 24 match norm_layer_type: 25 case "layer_norm": 26 norm_layer = nn.LayerNorm 27 case "static_layer_norm": 28 norm_layer = StaticLayerNorm 29 case "none": 30 norm_layer = None 31 case _: 32 raise ValueError(f"Invalid norm_layer_type: {norm_layer_type}") 33 34 return norm_layer