Source code for evorl.ec.operators.utils
1import jax.tree_util as jtu
2
3
[docs]
def is_layer_norm_layer(path: tuple[object, ...]) -> bool:
    """Return whether a pytree key path contains a "LayerNorm" dict key.

    Args:
        path: A pytree key path, i.e. a tuple of key entries such as the
            paths produced by ``jax.tree_util.tree_flatten_with_path``. Only
            ``jtu.DictKey`` entries are inspected; other entry kinds
            (``SequenceKey``, ``GetAttrKey``, ...) are ignored.

    Returns:
        True if any ``DictKey`` entry in ``path`` has a string key containing
        the substring ``"LayerNorm"``, otherwise False (including for an
        empty path).
    """
    return any(
        isinstance(entry, jtu.DictKey)
        # Dict keys in a pytree may be any hashable (int, tuple, ...); only
        # string keys can name a LayerNorm module, and a substring test on a
        # non-string key would raise TypeError.
        and isinstance(entry.key, str)
        and "LayerNorm" in entry.key
        for entry in path
    )