evorl.networks.spectral_norm

Flax-style Dense module with Spectral Normalization.

From https://github.com/google/brax/blob/main/brax/training/networks.py

References:
  • Dense: https://github.com/google/flax/blob/main/flax/linen/linear.py

  • Spectral Normalization:
      ◦ https://arxiv.org/abs/1802.05957
      ◦ https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/spectral_norm.py

Module Contents

Classes

SNDense

Dense layer with spectral normalization.

Data

API

evorl.networks.spectral_norm.Array

evorl.networks.spectral_norm.Dtype

class evorl.networks.spectral_norm.SNDense

Bases: flax.linen.Module

Dense layer with spectral normalization.

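A linear transformation applied over the last dimension of the input, in which the kernel is divided by an estimate of its largest singular value obtained by power iteration (spectral normalization, https://arxiv.org/abs/1802.05957).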
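As a rough illustration of this transformation, the sketch below normalizes a 2-D kernel with plain JAX power iteration. It is a minimal sketch, not the module's source: the helper names `l2_normalize` and `spectral_normalize`, the vector `u`, and the exact placement of `eps` are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def l2_normalize(x, eps=1e-4):
    # Scale x to (almost) unit L2 norm; eps keeps the denominator positive.
    return x * jax.lax.rsqrt(jnp.sum(x * x) + eps)


def spectral_normalize(kernel, u, n_steps=1, eps=1e-4):
    """Divide a 2-D kernel of shape (in, out) by an estimate of its top singular value.

    u has shape (out,) and approximates the top right singular vector.
    Assumes n_steps >= 1.
    """
    for _ in range(n_steps):
        v = l2_normalize(kernel @ u, eps)    # left vector, shape (in,)
        u = l2_normalize(kernel.T @ v, eps)  # right vector, shape (out,)
    # The singular vectors are treated as constants when differentiating.
    u = jax.lax.stop_gradient(u)
    v = jax.lax.stop_gradient(v)
    sigma = v @ kernel @ u                   # estimate of the largest singular value
    return kernel / sigma, u, sigma


key_w, key_u = jax.random.split(jax.random.PRNGKey(0))
kernel = jax.random.normal(key_w, (8, 4))
u0 = jax.random.normal(key_u, (4,))
w_sn, u, sigma = spectral_normalize(kernel, u0, n_steps=50)
print(sigma, jnp.linalg.norm(kernel, ord=2))  # should nearly agree
print(jnp.linalg.norm(w_sn, ord=2))           # close to 1.0
```

With enough iterations the estimate matches `jnp.linalg.norm(kernel, ord=2)`, so the normalized kernel has a spectral norm close to 1.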

Variables:
  • features – the number of output features.

  • use_bias – whether to add a bias to the output (default: True).

  • dtype – the dtype of the computation (default: float32).

  • precision – the numerical precision of the computation; see jax.lax.Precision for details.

  • kernel_init – initializer function for the weight matrix.

  • bias_init – initializer function for the bias.

  • eps – the constant used for numerical stability (default: 1e-4).

  • n_steps – the number of power-iteration steps used to approximate the largest singular value of the kernel (default: 1).
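The attributes listed above map onto a functional layer roughly as follows. This is a hypothetical pure-JAX sketch (`init_sn_dense` and `sn_dense_apply` are not part of evorl), assuming the Flax Dense convention of a kernel with shape `(in_features, features)` and a running singular-vector estimate `u` kept as separate state; using `dtype` for the parameters is a simplification.

```python
import jax
import jax.numpy as jnp


def l2_normalize(x, eps):
    # Unit-normalize x; eps keeps the reciprocal square root finite.
    return x * jax.lax.rsqrt(jnp.sum(x * x) + eps)


def init_sn_dense(key, in_features, features, use_bias=True, dtype=jnp.float32,
                  kernel_init=jax.nn.initializers.lecun_normal(),
                  bias_init=jax.nn.initializers.zeros):
    """Hypothetical initializer: parameters plus power-iteration state."""
    k_kernel, k_bias, k_u = jax.random.split(key, 3)
    params = {"kernel": kernel_init(k_kernel, (in_features, features), dtype)}
    if use_bias:
        params["bias"] = bias_init(k_bias, (features,), dtype)
    # Running estimate of the top right singular vector, kept outside params.
    state = {"u": jax.random.normal(k_u, (features,), dtype)}
    return params, state


def sn_dense_apply(params, state, inputs, eps=1e-4, n_steps=1, precision=None):
    """Hypothetical forward pass: inputs @ (kernel / sigma) + bias over the last axis."""
    kernel = params["kernel"]
    u = state["u"]
    for _ in range(n_steps):  # assumes n_steps >= 1
        v = l2_normalize(jnp.matmul(kernel, u, precision=precision), eps)
        u = l2_normalize(jnp.matmul(kernel.T, v, precision=precision), eps)
    u = jax.lax.stop_gradient(u)
    v = jax.lax.stop_gradient(v)
    sigma = v @ kernel @ u  # estimated largest singular value of the kernel
    y = jnp.matmul(inputs, kernel / sigma, precision=precision)
    if "bias" in params:
        y = y + params["bias"]
    return y, {"u": u}  # pass the updated state to the next call


params, state = init_sn_dense(jax.random.PRNGKey(0), in_features=3, features=5)
x = jnp.ones((2, 3))
y, state = sn_dense_apply(params, state, x)
print(y.shape)  # (2, 5)
```

Feeding the updated `u` back on every call lets the power iteration keep refining across training steps, which is why a single step per update, as in the spectral normalization paper, is usually enough.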

bias_init: collections.abc.Callable[[brax.training.types.PRNGKey, evorl.networks.spectral_norm.Shape, evorl.networks.spectral_norm.Dtype], evorl.networks.spectral_norm.Array]

None

dtype: Any

float32

eps: float

0.0001
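To see why `eps` matters, consider the vector normalization used inside power iteration. The sketch assumes `eps` is added under the square root, which is one common form; the hypothetical helper `l2_normalize` is for illustration only, and how the module itself applies `eps` is not shown on this page.

```python
import jax
import jax.numpy as jnp


def l2_normalize(x, eps=1e-4):
    # Assumed form: x / sqrt(sum(x**2) + eps); eps keeps the denominator positive.
    return x * jax.lax.rsqrt(jnp.sum(x * x) + eps)


zero = jnp.zeros(3)
print(l2_normalize(zero))             # all zeros, no NaN
print(zero / jnp.linalg.norm(zero))   # plain normalization gives NaN here

unit = l2_normalize(jnp.array([3.0, 4.0]))
print(jnp.linalg.norm(unit))          # about 1.0; eps barely changes well-scaled inputs
```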

features: int

No default (required).

kernel_init: collections.abc.Callable[[brax.training.types.PRNGKey, evorl.networks.spectral_norm.Shape, evorl.networks.spectral_norm.Dtype], evorl.networks.spectral_norm.Array]

‘lecun_normal(…)’

n_steps: int

1
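The effect of `n_steps` can be checked by running power iteration from a fixed random vector for an increasing number of steps. The helper `estimate_sigma` below is hypothetical and restarts from the same `u0` each time, unlike a layer that carries its estimate between calls.

```python
import jax
import jax.numpy as jnp


def l2_normalize(x, eps=1e-4):
    # Unit-normalize x; eps keeps the denominator positive.
    return x * jax.lax.rsqrt(jnp.sum(x * x) + eps)


def estimate_sigma(kernel, u, n_steps, eps=1e-4):
    # Run n_steps (>= 1) of power iteration from u and return the estimate v^T K u.
    for _ in range(n_steps):
        v = l2_normalize(kernel @ u, eps)
        u = l2_normalize(kernel.T @ v, eps)
    return v @ kernel @ u


key_w, key_u = jax.random.split(jax.random.PRNGKey(1))
kernel = jax.random.normal(key_w, (16, 8))
u0 = jax.random.normal(key_u, (8,))
true_sigma = jnp.linalg.norm(kernel, ord=2)
for n in (1, 2, 5, 100):
    est = estimate_sigma(kernel, u0, n)
    print(n, float(est), float(true_sigma))
```

Each estimate is at most the true largest singular value, and it approaches that value as the number of steps grows.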

precision: Any

None

use_bias: bool

True

evorl.networks.spectral_norm.Shape
