1# Copyright 2023 The Brax Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Flax-style Dense module with Spectral Normalization.
16
17From https://github.com/google/brax/blob/main/brax/training/networks.py
18
19Reference:
20 Dense: https://github.com/google/flax/blob/main/flax/linen/linear.py
21 Spectral Normalization:
22 - https://arxiv.org/abs/1802.05957
23 - https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/spectral_norm.py
24"""
25
26from collections.abc import Callable
27from typing import Any
28
29import jax.numpy as jnp
30from brax.training.types import PRNGKey
31from flax import linen
32from flax.linen.initializers import lecun_normal, normal, zeros
33from jax import lax
34
# Loose aliases used only in annotations below.
Array = Any
# A shape is a tuple of ints of arbitrary length; `tuple[int]` would describe
# a tuple holding exactly one int.
Shape = tuple[int, ...]
Dtype = Any
38
39
def _l2_normalize(x, axis=None, eps=1e-12):
    """Scales `x` to (approximately) unit L2 norm along `axis`.

    The squared norm is offset by `eps` before the reciprocal square root is
    taken, so an all-zero input yields zeros instead of NaNs.

    Args:
      x: An input ndarray.
      axis: Dimension along which to normalize, e.g. `1` to separately
        normalize vectors in a batch. Passing `None` treats `x` as one
        flattened vector when computing the norm (equivalent to the
        Frobenius norm).
      eps: Small constant added to the squared norm to avoid dividing by zero.

    Returns:
      An array of the same shape as `x`, L2-normalized along `axis`.
    """
    squared_norm = (x * x).sum(axis=axis, keepdims=True)
    inverse_norm = lax.rsqrt(squared_norm + eps)
    return x * inverse_norm
56
57
[docs]
class SNDense(linen.Module):
    """Dense layer with spectral normalization.

    A linear transformation applied over the last dimension of the input,
    where the kernel is divided by a power-iteration estimate of its largest
    singular value (https://arxiv.org/abs/1802.05957).

    The running estimate of the singular vector is stored as variable `u0` in
    the `"sing_vec"` collection and is written back on every call, so callers
    need to pass that collection as mutable. The `"sing_vec"` RNG stream is
    only consumed when `u0` is first created.

    Attributes:
      features: the number of output features.
      use_bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).
      precision: numerical precision of the computation, see
        `jax.lax.Precision` for details.
      kernel_init: initializer function for the weight matrix.
      bias_init: initializer function for the bias.
      eps: the constant used for numerical stability when normalizing the
        power-iteration vectors.
      n_steps: how many steps of power iteration to perform to approximate
        the largest singular value of the kernel. Must be at least 1.
    """

    features: int
    use_bias: bool = True
    dtype: Any = jnp.float32
    precision: Any = None
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = lecun_normal()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
    eps: float = 1e-4
    n_steps: int = 1

    @linen.compact
    def __call__(self, inputs: Array) -> Array:
        """Applies the spectrally normalized linear map along the last axis.

        Args:
          inputs: The nd-array to be transformed.

        Returns:
          The transformed input.

        Raises:
          ValueError: if `n_steps` is smaller than 1, or if the kernel has
            fewer than two dimensions.
        """
        # Without at least one power-iteration step `v0` is never computed.
        if self.n_steps < 1:
            raise ValueError(
                f"n_steps must be at least 1, got {self.n_steps}."
            )

        inputs = jnp.asarray(inputs, self.dtype)
        kernel = self.param(
            "kernel", self.kernel_init, (inputs.shape[-1], self.features)
        )
        kernel = jnp.asarray(kernel, self.dtype)

        kernel_shape = kernel.shape
        # Handle scalars.
        if kernel.ndim <= 1:
            raise ValueError(
                "Spectral normalization is not well defined for scalar inputs."
            )
        # Handle higher-order tensors by flattening all leading dimensions.
        if kernel.ndim > 2:
            kernel = jnp.reshape(kernel, [-1, kernel.shape[-1]])

        def _init_u0(shape):
            # Draw the RNG key inside the initializer: the variable's init
            # function only runs when `u0` does not exist yet, so later calls
            # do not need a "sing_vec" RNG stream.
            return normal(stddev=1.0)(self.make_rng("sing_vec"), shape)

        u0_state = self.variable(
            "sing_vec", "u0", _init_u0, (1, kernel.shape[-1])
        )
        u0 = u0_state.value

        # Power iteration for the weight's singular value.
        kernel_t = kernel.transpose([1, 0])
        for _ in range(self.n_steps):
            v0 = _l2_normalize(jnp.matmul(u0, kernel_t), eps=self.eps)
            u0 = _l2_normalize(jnp.matmul(v0, kernel), eps=self.eps)

        # The singular vectors are treated as constants: gradients only flow
        # through `kernel` in the sigma estimate below.
        u0 = lax.stop_gradient(u0)
        v0 = lax.stop_gradient(v0)

        sigma = jnp.matmul(jnp.matmul(v0, kernel), jnp.transpose(u0))[0, 0]

        # A zero estimate (e.g. an all-zero kernel) would otherwise turn the
        # whole kernel into NaN/inf; leave the kernel unscaled in that case.
        kernel = kernel / jnp.where(sigma != 0, sigma, 1)
        kernel = kernel.reshape(kernel_shape)

        # Persist the refined singular-vector estimate for the next call.
        u0_state.value = u0

        # Contract the last axis of `inputs` with the first axis of `kernel`.
        y = lax.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,))
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return y