Source code for evorl.ec.operators.utils

1import jax.tree_util as jtu
2
3
def is_layer_norm_layer(path: tuple[object, ...]) -> bool:
    """Return whether a pytree key path passes through a ``LayerNorm`` entry.

    Args:
        path: A pytree key path, i.e. a sequence of key entries (for example
            a path produced by ``jax.tree_util.tree_flatten_with_path``).
            Only ``jtu.DictKey`` entries are inspected; any other entry kind
            (``SequenceKey``, ``GetAttrKey``, ...) is ignored.

    Returns:
        True if at least one ``DictKey`` entry has a ``str`` key containing
        the substring ``"LayerNorm"``; False otherwise, including for an
        empty path.
    """
    # NOTE(review): the substring match presumably targets auto-generated
    # module names such as "LayerNorm_0" (Flax-style) -- confirm with callers.
    # DictKey.key can be any hashable (e.g. an int); substring matching only
    # applies to str keys, and `"LayerNorm" in 0` would raise TypeError.
    return any(
        isinstance(entry, jtu.DictKey)
        and isinstance(entry.key, str)
        and "LayerNorm" in entry.key
        for entry in path
    )