evorl.utils.running_statistics

Module Contents

Classes

NestedMeanStd

A container for running statistics (mean, std) of possibly nested data.

RunningStatisticsState

Full state of running statistics computation.

Functions

denormalize

Denormalizes values in a nested structure using the given mean/std.

init_state

Initializes the running statistics for the given nested structure.

normalize

Normalizes data using running statistics.

update

Updates the running statistics with the given batch of data.

API

class evorl.utils.running_statistics.NestedMeanStd

Bases: evorl.types.PyTreeData

A container for running statistics (mean, std) of possibly nested data.

mean: chex.ArrayTree

std: chex.ArrayTree

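Because the statistics describe possibly nested data, `mean` and `std` are trees with the same structure as the data, with one leaf per data leaf. The NumPy sketch below illustrates that layout with plain dicts. `MeanStdSketch`, `tree_map`, and `batch_mean_std` are hypothetical helpers for illustration only (the real class is a JAX pytree built on `evorl.types.PyTreeData`), and taking the statistics over the leading batch axis is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Callable

import numpy as np


def tree_map(fn: Callable[..., Any], tree: Any) -> Any:
    """Apply fn to every leaf of a nest of dicts (tiny stand-in for jax.tree_util.tree_map)."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    return fn(tree)


@dataclass(frozen=True)
class MeanStdSketch:
    """Hypothetical stand-in for NestedMeanStd: mean and std mirror the data nest."""

    mean: Any
    std: Any


def batch_mean_std(batch: Any) -> MeanStdSketch:
    # Statistics are taken over the leading (batch) axis of every leaf, so each
    # leaf of mean/std has the shape of a single data point.
    return MeanStdSketch(
        mean=tree_map(lambda x: x.mean(axis=0), batch),
        std=tree_map(lambda x: x.std(axis=0), batch),
    )


# A nested observation batch with 2 samples.
obs = {
    "position": np.array([[0.0, 1.0], [2.0, 3.0]]),
    "velocity": {"linear": np.array([[1.0], [3.0]])},
}
stats = batch_mean_std(obs)
print(stats.mean["position"])  # [1. 2.]
print(stats.std["velocity"]["linear"])  # [1.]
```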

class evorl.utils.running_statistics.RunningStatisticsState

Bases: evorl.utils.running_statistics.NestedMeanStd

Full state of running statistics computation.

count: chex.Array

summed_variance: chex.ArrayTree

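The state keeps `count` and `summed_variance` alongside `mean` and `std`. A minimal single-leaf sketch of how they could relate, assuming (as in the Brax running-statistics module, whose API this one closely resembles) that `summed_variance` holds the sum of squared deviations from the mean and that `std` is derived as `sqrt(summed_variance / count)`, clipped to `[std_min_value, std_max_value]`. `std_from_state` is a hypothetical helper, not part of evorl.

```python
import numpy as np


def std_from_state(count, summed_variance, std_min_value=1e-6, std_max_value=1e6):
    """Hypothetical helper: recover a clipped std from (count, summed_variance).

    Assumes summed_variance is the sum of squared deviations from the mean.
    """
    return np.clip(np.sqrt(summed_variance / count), std_min_value, std_max_value)


# Running state for one leaf after seeing four scalar data points.
data = np.array([1.0, 2.0, 4.0, 7.0])
count = data.shape[0]
mean = data.mean()
summed_variance = np.sum((data - mean) ** 2)  # 6.25 + 2.25 + 0.25 + 12.25
print(count, mean, summed_variance)  # 4 3.5 21.0
std = std_from_state(count, summed_variance)  # sqrt(21 / 4), about 2.29

# A constant feature has zero variance; clipping keeps std away from zero, so
# dividing by it during normalization stays finite.
flat_std = std_from_state(3, np.float64(0.0))
print(flat_std)  # 1e-06
```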

evorl.utils.running_statistics.denormalize(batch: chex.Array, mean_std: evorl.utils.running_statistics.NestedMeanStd) -> chex.Array

Denormalizes values in a nested structure using the given mean/std.

Only values of inexact (floating-point or complex) dtypes are denormalized; see https://numpy.org/doc/stable/_images/dtype-hierarchy.png for the NumPy type hierarchy.

Parameters:
  • batch – a nested structure containing a batch of data.

  • mean_std – mean and standard deviation used for denormalization.

Returns:

Nested structure with denormalized values.
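A NumPy sketch of the behavior described above, assuming denormalization is the inverse of normalization, `x * std + mean`, applied leaf-wise. `denormalize_sketch` is a hypothetical name and walks plain dicts instead of JAX pytrees; the dtype check uses `np.issubdtype(..., np.inexact)` to express the "inexact types only" rule.

```python
from typing import Any

import numpy as np


def denormalize_sketch(batch: Any, mean: Any, std: Any) -> Any:
    """Hypothetical leaf-wise inverse of normalization over a nest of dicts."""
    if isinstance(batch, dict):
        return {key: denormalize_sketch(batch[key], mean[key], std[key]) for key in batch}
    batch = np.asarray(batch)
    # Only inexact (floating-point or complex) leaves are rescaled; integer and
    # boolean leaves, e.g. step counters, pass through unchanged.
    if not np.issubdtype(batch.dtype, np.inexact):
        return batch
    return batch * std + mean


batch = {"obs": np.array([[0.0, -1.0], [1.0, 2.0]]), "step": np.array([3, 4])}
mean = {"obs": np.array([10.0, 20.0]), "step": np.array(0.5)}
std = {"obs": np.array([2.0, 0.5]), "step": np.array(2.0)}
out = denormalize_sketch(batch, mean, std)
# out["obs"] holds [[10.0, 19.5], [12.0, 21.0]]; out["step"] is still [3, 4].
```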

evorl.utils.running_statistics.init_state(nest: chex.ArrayTree, int_counter: bool = False) -> evorl.utils.running_statistics.RunningStatisticsState

Initializes the running statistics for the given nested structure.
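A minimal sketch of what an initial state could look like, assuming zero `mean` and `summed_variance`, unit `std` (so that normalizing before the first update leaves data essentially unchanged), and a zero `count`. Reading `int_counter` as "use an integer count instead of a floating-point one" is an assumption based only on the parameter name, as is the fixed float32 dtype; `StateSketch`, `init_state_sketch`, and `tree_map` are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable

import numpy as np


def tree_map(fn: Callable[[np.ndarray], np.ndarray], tree: Any) -> Any:
    """Apply fn to every leaf of a nest of dicts."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    return fn(np.asarray(tree))


@dataclass(frozen=True)
class StateSketch:
    """Hypothetical stand-in for RunningStatisticsState."""

    count: np.ndarray
    mean: Any
    summed_variance: Any
    std: Any


def init_state_sketch(nest: Any, int_counter: bool = False) -> StateSketch:
    # Assumption: int_counter picks an integer count instead of a float one.
    count = np.zeros((), dtype=np.int32 if int_counter else np.float32)
    return StateSketch(
        count=count,
        # Neutral starting point: zero mean and variance, unit std.
        mean=tree_map(lambda x: np.zeros(x.shape, np.float32), nest),
        summed_variance=tree_map(lambda x: np.zeros(x.shape, np.float32), nest),
        std=tree_map(lambda x: np.ones(x.shape, np.float32), nest),
    )


# `nest` is one example data point (no batch axis) describing the leaf shapes.
state = init_state_sketch({"obs": np.zeros(3), "goal": {"xy": np.zeros(2)}})
print(state.count.dtype, state.std["goal"]["xy"])  # float32 [1. 1.]
```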

evorl.utils.running_statistics.normalize(batch: chex.Array, mean_std: evorl.utils.running_statistics.NestedMeanStd, eps: float = 1e-08, max_abs_value: float | None = None) -> chex.Array

Normalizes data using running statistics.
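The signature suggests the usual standardization `(x - mean) / std`, with `eps` guarding the division and `max_abs_value` optionally clipping the result. A NumPy sketch under those assumptions; exactly where evorl applies `eps`, and whether it skips non-inexact leaves as `denormalize` does, are guesses here, and `normalize_sketch` is a hypothetical name.

```python
from typing import Any, Optional

import numpy as np


def normalize_sketch(
    batch: Any,
    mean: Any,
    std: Any,
    eps: float = 1e-8,
    max_abs_value: Optional[float] = None,
) -> Any:
    """Hypothetical leaf-wise normalization over a nest of dicts."""
    if isinstance(batch, dict):
        return {
            key: normalize_sketch(batch[key], mean[key], std[key], eps, max_abs_value)
            for key in batch
        }
    batch = np.asarray(batch)
    # Assumption: like denormalize, only inexact (float/complex) leaves change.
    if not np.issubdtype(batch.dtype, np.inexact):
        return batch
    # Assumption: eps is added to std to guard against division by ~0.
    out = (batch - mean) / (std + eps)
    if max_abs_value is not None:
        # Optionally clip outliers to [-max_abs_value, max_abs_value].
        out = np.clip(out, -max_abs_value, max_abs_value)
    return out


batch = {"obs": np.array([[12.0, 19.0], [8.0, 120.0]]), "done": np.array([0, 1])}
mean = {"obs": np.array([10.0, 20.0]), "done": np.array(0.0)}
std = {"obs": np.array([2.0, 1.0]), "done": np.array(1.0)}
out = normalize_sketch(batch, mean, std, max_abs_value=5.0)
# out["obs"] is about [[1, -1], [-1, 5]]: the 100-sigma outlier is clipped to 5.
```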

evorl.utils.running_statistics.update(state: evorl.utils.running_statistics.RunningStatisticsState, batch: chex.ArrayTree, *, weights: chex.Array | None = None, std_min_value: float = 1e-06, std_max_value: float = 1000000.0, dp_axis_name: str | None = None, validate_shapes: bool = True) -> evorl.utils.running_statistics.RunningStatisticsState

Updates the running statistics with the given batch of data.

Note: the data batch and the state elements (mean, etc.) must have the same structure.

Note: by default, int32 is used for counts and float32 for the accumulated variance. As a result, the count overflows after 2^31 data points, and precision degrades after 2^24 batch updates, or even earlier if the variance updates have a large dynamic range. To improve precision, consider setting jax_enable_x64 to True; see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
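The two limits in the note follow from the 32-bit formats themselves and can be reproduced with NumPy's fixed-width types. NumPy is used here only to show the arithmetic; JAX's int32 and float32 use the same 32-bit representations.

```python
import numpy as np

# int32 counter: 2**31 - 1 is the largest count; one more data point wraps.
count = np.array([2**31 - 1], dtype=np.int32)
wrapped = count + np.array([1], dtype=np.int32)
print(wrapped)  # [-2147483648]

# float32 has a 24-bit significand: above 2**24, adding 1.0 is rounded away.
acc32 = np.float32(2**24)
print(acc32 + np.float32(1.0) == acc32)  # True

# With 64-bit floats (the types jax_enable_x64 enables in JAX) the increment survives.
acc64 = np.float64(2**24)
print(acc64 + 1.0 == acc64)  # False
```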

Parameters:
  • state – The running statistics before the update.

  • batch – The data to be used to update the running statistics.

  • weights – Weights of the batch data. Should match the batch dimensions. A weight of 2.0 should be equivalent to updating on the corresponding data point twice.

  • std_min_value – Minimum value for the standard deviation.

  • std_max_value – Maximum value for the standard deviation.

  • dp_axis_name – Name of the pmapped axis, if any.

  • validate_shapes – If True, the shapes of all leaves of the batch are validated. Enabled by default; this has no performance impact when jitted.

Returns:

Updated running statistics.
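The update can be sketched with a weighted, batched variant of Welford's algorithm, one standard way to maintain a mean and a sum of squared deviations incrementally. The sketch below handles a single array leaf in NumPy. That evorl uses exactly this recurrence is an assumption (it matches the Brax module this API resembles); `update_sketch` is a hypothetical name, and cross-device reduction over `dp_axis_name` and shape validation are omitted.

```python
from typing import Optional

import numpy as np


def update_sketch(
    count: float,
    mean: np.ndarray,
    summed_variance: np.ndarray,
    batch: np.ndarray,
    weights: Optional[np.ndarray] = None,
    std_min_value: float = 1e-6,
    std_max_value: float = 1e6,
):
    """Hypothetical single-leaf, weighted batch form of Welford's algorithm."""
    batch = np.asarray(batch, dtype=np.float64)
    if weights is None:
        weights = np.ones(batch.shape[:1])
    weights = np.asarray(weights, dtype=np.float64)
    # Weights cover the leading batch dimensions; reduce over exactly those.
    batch_axes = tuple(range(weights.ndim))
    w = weights.reshape(weights.shape + (1,) * (batch.ndim - weights.ndim))

    count = count + weights.sum()
    diff_to_old_mean = batch - mean
    mean = mean + np.sum(w * diff_to_old_mean, axis=batch_axes) / count
    diff_to_new_mean = batch - mean
    # Summing (x - old_mean) * (x - new_mean) updates the sum of squared
    # deviations exactly, without revisiting earlier data.
    summed_variance = summed_variance + np.sum(
        w * diff_to_old_mean * diff_to_new_mean, axis=batch_axes
    )
    std = np.clip(np.sqrt(summed_variance / count), std_min_value, std_max_value)
    return count, mean, summed_variance, std


# Two sequential updates match statistics computed over all data at once.
rng = np.random.default_rng(0)
first, second = rng.normal(size=(5, 3)), rng.normal(size=(7, 3))
count, mean, summed_var, _ = update_sketch(0.0, np.zeros(3), np.zeros(3), first)
count, mean, summed_var, std = update_sketch(count, mean, summed_var, second)
full = np.concatenate([first, second])
print(np.allclose(mean, full.mean(axis=0)), np.allclose(std, full.std(axis=0)))  # True True

# A weight of 2.0 behaves like seeing that data point twice.
_, m_w, _, s_w = update_sketch(
    0.0, np.zeros(1), np.zeros(1), np.array([[1.0], [4.0]]), weights=np.array([2.0, 1.0])
)
_, m_d, _, s_d = update_sketch(0.0, np.zeros(1), np.zeros(1), np.array([[1.0], [1.0], [4.0]]))
print(np.allclose(m_w, m_d), np.allclose(s_w, s_d))  # True True
```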