evorl.networks

Package Contents

Classes

MLP

Multi-layer perceptron (MLP) module.

SNMLP

MLP module with Spectral Normalization.

StaticLayerNorm

Layer normalization with fixed scale and bias.

Functions

make_discrete_q_network

Creates a Q network for a discrete action space: (obs) -> q_values.

make_mlp

Creates an MLP network.

make_policy_network

Creates a policy network.

make_q_network

Creates a Q network: (obs, action) -> value.

make_v_network

Creates a V network: (obs) -> value.

make_vmap_mlp

Creates multiple MLP networks that are evaluated in parallel via vmap.

API

class evorl.networks.MLP

Bases: flax.linen.Module

Multi-layer perceptron (MLP) module.

activation: evorl.networks.linear.ActivationFn

None

activation_final: evorl.networks.linear.ActivationFn | None

None

kernel_init: evorl.networks.linear.Initializer

‘lecun_uniform(…)’

layer_sizes: collections.abc.Sequence[int]

None

norm_layer: flax.linen.Module | None

None

use_bias: bool

True
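The forward pass these fields describe can be sketched in plain NumPy. This is a minimal sketch under assumptions, not evorl's code: it assumes each entry of `layer_sizes` is one dense layer, that `activation` is applied after every layer except the last, and that `activation_final` (if set) is applied after the last one; `norm_layer` is left out. The helpers `init_mlp` and `mlp_apply` are hypothetical names.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def lecun_uniform(rng, fan_in, fan_out):
    # LeCun uniform: U(-limit, limit) with limit = sqrt(3 / fan_in).
    limit = np.sqrt(3.0 / fan_in)
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))


def init_mlp(rng, in_dim, layer_sizes, use_bias=True):
    # Hypothetical helper: one (kernel, bias) pair per entry in layer_sizes.
    params = []
    for out_dim in layer_sizes:
        kernel = lecun_uniform(rng, in_dim, out_dim)
        bias = np.zeros(out_dim) if use_bias else None
        params.append((kernel, bias))
        in_dim = out_dim
    return params


def mlp_apply(params, x, activation=relu, activation_final=None):
    # Hypothetical helper: hidden layers use `activation`; the last layer
    # uses `activation_final` when given, otherwise it stays linear.
    last = len(params) - 1
    for i, (kernel, bias) in enumerate(params):
        x = x @ kernel
        if bias is not None:
            x = x + bias
        if i < last:
            x = activation(x)
        elif activation_final is not None:
            x = activation_final(x)
    return x


rng = np.random.default_rng(0)
params = init_mlp(rng, in_dim=8, layer_sizes=(256, 256, 4))
y = mlp_apply(params, np.ones((32, 8)), activation_final=np.tanh)
print(y.shape)  # (32, 4)
```

The `lecun_uniform` helper uses the same scaling as `jax.nn.initializers.lecun_uniform()`, the default `kernel_init` listed above.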

class evorl.networks.SNMLP

Bases: flax.linen.Module

MLP module with Spectral Normalization.

activation: evorl.networks.linear.ActivationFn

None

activation_final: evorl.networks.linear.ActivationFn | None

None

kernel_init: evorl.networks.linear.Initializer

‘lecun_uniform(…)’

layer_sizes: collections.abc.Sequence[int]

None

use_bias: bool

True
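Spectral normalization divides a weight matrix by an estimate of its largest singular value, so the normalized kernel has spectral norm (approximately) 1. A minimal NumPy sketch of the usual power-iteration estimate follows; `spectral_normalize` and its parameters are hypothetical, and this reference does not state how `SNMLP` configures the normalization.

```python
import numpy as np


def spectral_normalize(w, n_iters=20, eps=1e-12, seed=0):
    """Return (w / sigma, sigma), with sigma = largest singular value of w
    estimated by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=w.shape[1])
    for _ in range(n_iters):
        v = w @ u
        v = v / (np.linalg.norm(v) + eps)
        u = w.T @ v
        u = u / (np.linalg.norm(u) + eps)
    sigma = v @ w @ u  # estimate of the top singular value
    return w / sigma, sigma


w = np.diag([3.0, 1.0, 0.5])
w_sn, sigma = spectral_normalize(w)
print(round(float(sigma), 4))  # 3.0
```

Flax provides a `flax.linen.SpectralNorm` wrapper that keeps the power-iteration vector between calls, so a single iteration per update is typically enough; whether `SNMLP` builds on that wrapper is not stated here.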

class evorl.networks.StaticLayerNorm

Bases: flax.linen.LayerNorm

Layer normalization with fixed scale and bias.

fixed_bias: jax.Array

‘field(…)’

fixed_scale: jax.Array

‘field(…)’

use_bias: bool

False

use_scale: bool

False
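A NumPy sketch of what the fields above suggest: with `use_scale` and `use_bias` set to `False`, no learnable affine parameters are created, and the constants `fixed_scale` and `fixed_bias` are assumed to be applied after normalization instead. That ordering, and the `eps` value of 1e-6 (the default of Flax's `LayerNorm`), are assumptions; `static_layer_norm` is a hypothetical name.

```python
import numpy as np


def static_layer_norm(x, fixed_scale, fixed_bias, eps=1e-6):
    # Normalize over the last axis; nothing here is a trainable parameter.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    # Assumed: a constant affine transform in place of learned scale/bias.
    return x_hat * fixed_scale + fixed_bias


x = np.array([[1.0, 2.0, 3.0, 4.0]])
y = static_layer_norm(x, fixed_scale=np.ones(4), fixed_bias=np.zeros(4))
y2 = static_layer_norm(x, fixed_scale=2.0 * np.ones(4), fixed_bias=0.5 * np.ones(4))
```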

evorl.networks.make_discrete_q_network(action_size: int, n_stack: int = 1, hidden_layer_sizes: collections.abc.Sequence[int] = (256, 256), activation: evorl.networks.linear.ActivationFn = nn.relu, kernel_init: evorl.networks.linear.Initializer = jax.nn.initializers.lecun_uniform(), norm_layer_type: str = 'none', obs_key: str = '') -> flax.linen.Module

Creates a Q network for a discrete action space: (obs) -> q_values.
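Because a discrete Q network maps an observation to one Q-value per action, greedy action selection and the Q-value of a taken action both come from a single forward pass. The NumPy sketch below uses a random linear layer as a hypothetical stand-in for the network returned by `make_discrete_q_network`; the `n_stack` argument is not modeled.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, obs_dim, action_size = 5, 4, 3

# Hypothetical stand-in for the Q network: one random linear layer.
w = rng.normal(size=(obs_dim, action_size))
obs = rng.normal(size=(batch, obs_dim))

q_values = obs @ w                          # (batch, action_size)
greedy_actions = q_values.argmax(axis=-1)   # (batch,) best action per row

# Q-value of the actions actually taken, e.g. for a TD loss.
taken = np.array([0, 2, 1, 1, 0])
q_taken = np.take_along_axis(q_values, taken[:, None], axis=-1)[:, 0]  # (batch,)
```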

evorl.networks.make_mlp(layer_sizes: collections.abc.Sequence[int], activation: evorl.networks.linear.ActivationFn = nn.relu, kernel_init: evorl.networks.linear.Initializer = jax.nn.initializers.lecun_uniform(), activation_final: evorl.networks.linear.ActivationFn | None = None, use_bias: bool = True, norm_layer_type: str = 'none') -> flax.linen.Module

Creates an MLP network.
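Assuming that each entry of `layer_sizes` is the output width of one dense layer (the last entry being the output size) and that the input width is inferred at initialization, the parameter count follows directly from the sizes. The sketch ignores normalization layers (`norm_layer_type='none'`); `mlp_param_count` is a hypothetical helper.

```python
def mlp_param_count(in_dim, layer_sizes, use_bias=True):
    # Each dense layer has in_dim * out_dim weights, plus out_dim biases.
    total = 0
    for out_dim in layer_sizes:
        total += in_dim * out_dim + (out_dim if use_bias else 0)
        in_dim = out_dim
    return total


print(mlp_param_count(17, (256, 256, 6)))                  # 71942
print(mlp_param_count(17, (256, 256, 6), use_bias=False))  # 71424
```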

evorl.networks.make_policy_network(action_size: int, hidden_layer_sizes: collections.abc.Sequence[int] = (256, 256), use_bias: bool = True, activation: evorl.networks.linear.ActivationFn = nn.relu, activation_final: evorl.networks.linear.ActivationFn | None = None, norm_layer_type: str = 'none', obs_key: str = '') -> flax.linen.Module

Creates a policy network.
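Setting `activation_final` to a squashing function such as `nn.tanh` bounds every output of the policy network to [-1, 1]; environments with other action bounds then need an affine rescale. A minimal NumPy sketch, assuming the network's last layer has `action_size` outputs; `scale_action` is a hypothetical helper.

```python
import numpy as np


def scale_action(a, low, high):
    # Affine map from [-1, 1] to [low, high], elementwise.
    return low + 0.5 * (a + 1.0) * (high - low)


raw = np.array([[-10.0, 0.0, 3.0]])   # pre-activation outputs, action_size = 3
a = np.tanh(raw)                      # squashed into [-1, 1]
low = np.array([-2.0, 0.0, -1.0])
high = np.array([2.0, 1.0, 1.0])
env_action = scale_action(a, low, high)
```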

evorl.networks.make_q_network(n_stack: int = 1, hidden_layer_sizes: collections.abc.Sequence[int] = (256, 256), activation: evorl.networks.linear.ActivationFn = nn.relu, kernel_init: evorl.networks.linear.Initializer = jax.nn.initializers.lecun_uniform(), norm_layer_type: str = 'none', obs_key: str = '') -> flax.linen.Module

Creates a Q network: (obs, action) -> value.
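For a critic with signature (obs, action) -> value, a common construction, assumed here, concatenates observation and action along the feature axis and maps the result to a single output that is squeezed to one value per example. The NumPy sketch uses hypothetical random weights for one hidden layer; `n_stack` is not modeled.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, obs_dim, action_dim, hidden = 4, 6, 2, 16

obs = rng.normal(size=(batch, obs_dim))
action = rng.uniform(-1.0, 1.0, size=(batch, action_dim))

# Hypothetical weights for a one-hidden-layer critic.
w1 = rng.normal(size=(obs_dim + action_dim, hidden)) / np.sqrt(obs_dim + action_dim)
w2 = rng.normal(size=(hidden, 1)) / np.sqrt(hidden)

x = np.concatenate([obs, action], axis=-1)           # (batch, obs_dim + action_dim)
value = (np.maximum(x @ w1, 0.0) @ w2).squeeze(-1)   # (batch,)
```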

evorl.networks.make_v_network(hidden_layer_sizes: collections.abc.Sequence[int] = (256, 256), activation: evorl.networks.linear.ActivationFn = nn.relu, kernel_init: evorl.networks.linear.Initializer = jax.nn.initializers.lecun_uniform(), norm_layer_type: str = 'none', obs_key: str = '') -> flax.linen.Module

Creates a V network: (obs) -> value.
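A V network maps the observation alone to a scalar value. In the NumPy sketch below, `select_obs` is a hypothetical helper illustrating one plausible reading of `obs_key`, namely that a non-empty key picks an entry from a dictionary observation; this reference does not document that behavior, and the dictionary keys used are made up.

```python
import numpy as np


def select_obs(obs, obs_key=""):
    # Hypothetical: assumed meaning of `obs_key`, not taken from evorl.
    return obs[obs_key] if obs_key else obs


def v_apply(w1, w2, obs):
    hidden = np.maximum(obs @ w1, 0.0)   # ReLU hidden layer
    return (hidden @ w2).squeeze(-1)     # (batch, 1) -> (batch,)


rng = np.random.default_rng(0)
batch, obs_dim, hidden_dim = 4, 6, 16
w1 = rng.normal(size=(obs_dim, hidden_dim)) / np.sqrt(obs_dim)
w2 = rng.normal(size=(hidden_dim, 1)) / np.sqrt(hidden_dim)

obs = {"state": rng.normal(size=(batch, obs_dim)), "extra": np.zeros((batch, 2))}
values = v_apply(w1, w2, select_obs(obs, obs_key="state"))
```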

evorl.networks.make_vmap_mlp(layer_sizes: collections.abc.Sequence[int], activation: evorl.networks.linear.ActivationFn = nn.relu, kernel_init: evorl.networks.linear.Initializer = jax.nn.initializers.lecun_uniform(), activation_final: evorl.networks.linear.ActivationFn | None = None, use_bias: bool = True, norm_layer_type: str = 'none', out_axes: int = -2)

Creates multiple MLP networks that are evaluated in parallel via vmap.
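Conceptually, an ensemble of N MLPs can be evaluated at once by stacking each layer's parameters along a new leading axis. The NumPy sketch below does this with `einsum` for one hidden layer and places the member axis second-to-last in the output, which is where `out_axes=-2` would put it under `jax.vmap` semantics (output shape `(batch, N, out_dim)`). It assumes the input is shared by all members; whether `make_vmap_mlp` shares inputs this way is not stated here.

```python
import numpy as np

rng = np.random.default_rng(0)
n_members, batch, in_dim, hidden, out_dim = 3, 5, 4, 8, 2

# Stacked parameters: the leading axis indexes ensemble members.
w1 = rng.normal(size=(n_members, in_dim, hidden))
b1 = np.zeros((n_members, hidden))
w2 = rng.normal(size=(n_members, hidden, out_dim))
b2 = np.zeros((n_members, out_dim))

x = rng.normal(size=(batch, in_dim))  # one input batch shared by all members

# Member axis "n" sits second-to-last, mirroring out_axes=-2.
h = np.maximum(np.einsum("bi,nih->bnh", x, w1) + b1, 0.0)  # (batch, n, hidden)
y = np.einsum("bnh,nho->bno", h, w2) + b2                   # (batch, n, out_dim)
```

Stacking parameters this way runs all members in one batched computation instead of a Python loop over N separate networks.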