Source code for evorl.networks.spectral_norm

  1# Copyright 2023 The Brax Authors.
  2#
  3# Licensed under the Apache License, Version 2.0 (the "License");
  4# you may not use this file except in compliance with the License.
  5# You may obtain a copy of the License at
  6#
  7#         http://www.apache.org/licenses/LICENSE-2.0
  8#
  9# Unless required by applicable law or agreed to in writing, software
 10# distributed under the License is distributed on an "AS IS" BASIS,
 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12# See the License for the specific language governing permissions and
 13# limitations under the License.
 14
 15"""Flax-style Dense module with Spectral Normalization.
 16
 17From https://github.com/google/brax/blob/main/brax/training/networks.py
 18
 19Reference:
 20    Dense: https://github.com/google/flax/blob/main/flax/linen/linear.py
 21    Spectral Normalization:
 22    - https://arxiv.org/abs/1802.05957
 23    - https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/spectral_norm.py
 24"""
 25
 26from collections.abc import Callable
 27from typing import Any
 28
 29import jax.numpy as jnp
 30from brax.training.types import PRNGKey
 31from flax import linen
 32from flax.linen.initializers import lecun_normal, normal, zeros
 33from jax import lax
 34
 35Array = Any
 36Shape = tuple[int]
 37Dtype = Any
 38
 39
 40def _l2_normalize(x, axis=None, eps=1e-12):
 41    """Normalizes along dimension `axis` using an L2 norm.
 42
 43    This specialized function exists for numerical stability reasons.
 44
 45    Args:
 46        x: An input ndarray.
 47        axis: Dimension along which to normalize, e.g. `1` to separately normalize
 48            vectors in a batch. Passing `None` views `t` as a flattened vector when
 49            calculating the norm (equivalent to Frobenius norm).
 50        eps: Epsilon to avoid dividing by zero.
 51
 52    Returns:
 53        An array of the same shape as 'x' L2-normalized along 'axis'.
 54    """
 55    return x * lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)
 56
 57
class SNDense(linen.Module):
    """Dense layer with spectral normalization.

    A linear transformation applied over the last dimension of the input. The
    kernel is divided by an estimate of its largest singular value
    (https://arxiv.org/abs/1802.05957). The estimate comes from power
    iteration on a persistent vector `u0` stored in the `"sing_vec"`
    variable collection.

    Usage notes:
        - `init` needs an rng stream named `"sing_vec"`. It is drawn from
          only while `u0` does not exist yet, to create its initial value.
        - Every call writes the refined `u0` back into `"sing_vec"`. In Flax
          that collection must therefore be mutable during `apply`.

    Attributes:
        features: the number of output features.
        use_bias: whether to add a bias to the output (default: True).
        dtype: the dtype of the computation (default: float32).
        precision: numerical precision of the final input/kernel contraction;
            see `jax.lax.Precision` for details.
        kernel_init: initializer function for the weight matrix.
        bias_init: initializer function for the bias.
        eps: constant used inside the L2 normalization for numerical stability.
        n_steps: number of power-iteration steps performed per call to
            approximate the largest singular value of the kernel. Must be >= 1.
    """

    features: int
    use_bias: bool = True
    dtype: Any = jnp.float32
    precision: Any = None
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = lecun_normal()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
    eps: float = 1e-4
    n_steps: int = 1

    @linen.compact
    def __call__(self, inputs: Array) -> Array:
        """Applies a spectrally normalized linear map along the last dimension.

        Args:
            inputs: The nd-array to be transformed.

        Returns:
            The transformed input, whose last dimension is `features`.

        Raises:
            ValueError: if `n_steps < 1`, or if the kernel has fewer than two
                dimensions.
        """
        # With zero steps the loop below never assigns `v0`, which previously
        # surfaced as an UnboundLocalError; fail early with a clear message.
        if self.n_steps < 1:
            raise ValueError(
                f"n_steps must be >= 1 for spectral normalization, got {self.n_steps}."
            )

        inputs = jnp.asarray(inputs, self.dtype)
        kernel = self.param(
            "kernel", self.kernel_init, (inputs.shape[-1], self.features)
        )
        kernel = jnp.asarray(kernel, self.dtype)

        kernel_shape = kernel.shape
        # The kernel is requested as 2-D above. These rank guards presumably
        # only matter if kernel_init ignores the requested shape.
        if kernel.ndim <= 1:
            raise ValueError(
                "Spectral normalization is not well defined for scalar inputs."
            )
        # Flatten higher-order kernels into a matrix over the last axis.
        elif kernel.ndim > 2:
            kernel = jnp.reshape(kernel, [-1, kernel.shape[-1]])

        # Draw from the "sing_vec" rng only when u0 has not been created yet.
        # Flax calls the initializer only for missing variables, so apply()
        # no longer has to provide this rng once u0 exists.
        key = (
            self.make_rng("sing_vec")
            if not self.has_variable("sing_vec", "u0")
            else None
        )
        u0_state = self.variable(
            "sing_vec", "u0", normal(stddev=1.0), key, (1, kernel.shape[-1])
        )
        u0 = u0_state.value

        # Power iteration for the kernel's largest singular value.
        # u0 has shape (1, out) and v0 has shape (1, in).
        for _ in range(self.n_steps):
            v0 = _l2_normalize(
                jnp.matmul(u0, kernel.transpose([1, 0])), eps=self.eps
            )
            u0 = _l2_normalize(jnp.matmul(v0, kernel), eps=self.eps)

        # The singular-vector estimates are treated as constants for gradients.
        u0 = lax.stop_gradient(u0)
        v0 = lax.stop_gradient(v0)

        # sigma = v0 @ W @ u0^T, the estimated largest singular value.
        sigma = jnp.matmul(jnp.matmul(v0, kernel), jnp.transpose(u0))[0, 0]

        kernel /= sigma
        kernel = kernel.reshape(kernel_shape)

        # Persist the refined estimate so the next call continues the iteration.
        u0_state.value = u0

        y = lax.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,))
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return y