evorl.utils.running_statistics
Module Contents
Classes
| NestedMeanStd | A container for running statistics (mean, std) of possibly nested data. |
| RunningStatisticsState | Full state of running statistics computation. |
Functions
| denormalize | Denormalizes values in a nested structure using the given mean/std. |
| init_state | Initializes the running statistics for the given nested structure. |
| normalize | Normalizes data using running statistics. |
| update | Updates the running statistics with the given batch of data. |
API
- class evorl.utils.running_statistics.NestedMeanStd
Bases: evorl.types.PyTreeData
A container for running statistics (mean, std) of possibly nested data.
- mean: chex.ArrayTree
None
- std: chex.ArrayTree
None
- class evorl.utils.running_statistics.RunningStatisticsState
Bases: evorl.utils.running_statistics.NestedMeanStd
Full state of running statistics computation.
- count: chex.Array
None
- summed_variance: chex.ArrayTree
None
- evorl.utils.running_statistics.denormalize(batch: chex.Array, mean_std: evorl.utils.running_statistics.NestedMeanStd) → chex.Array
Denormalizes values in a nested structure using the given mean/std.
Only values of inexact (floating-point or complex) types are denormalized; other values are left unchanged. See https://numpy.org/doc/stable/_images/dtype-hierarchy.png for the NumPy type hierarchy.
- Parameters:
batch – A nested structure containing a batch of data.
mean_std – Mean and standard deviation used for denormalization.
- Returns:
Nested structure with denormalized values.
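The description above suggests that denormalization is the inverse of normalization, applied leaf by leaf. The following is a minimal sketch of that idea, assuming the formula `x * std + mean` (the page does not state it) and using plain Python dicts and lists as a stand-in for a pytree of arrays. The name `denormalize_sketch` is hypothetical and not part of evorl.

```python
def denormalize_sketch(batch, mean, std):
    """Apply x * std + mean to every float leaf; pass other leaves through."""
    out = {}
    for key, values in batch.items():
        if all(isinstance(v, float) for v in values):
            # Inexact (float) leaf: undo the assumed (x - mean) / std scaling.
            out[key] = [v * std[key] + mean[key] for v in values]
        else:
            # Non-float leaf (e.g. integer ids): left untouched.
            out[key] = list(values)
    return out


batch = {"obs": [-1.0, 0.0, 1.0], "action_id": [0, 2, 1]}
mean = {"obs": 10.0}
std = {"obs": 2.0}
restored = denormalize_sketch(batch, mean, std)
```

Here only the `obs` leaf is rescaled; the integer `action_id` leaf is copied unchanged, mirroring the "inexact types only" rule.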
- evorl.utils.running_statistics.init_state(nest: chex.ArrayTree, int_counter: bool = False) → evorl.utils.running_statistics.RunningStatisticsState
Initializes the running statistics for the given nested structure.
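As a rough illustration of what an initial state might look like, the sketch below builds statistics that mirror the structure of a template. It assumes zero means, unit standard deviations, zero accumulated variance, and a zero counter whose type is selected by `int_counter`; none of these initial values are confirmed by this page. `RunningStatsSketch` and `init_state_sketch` are hypothetical names, and dicts of lists stand in for pytrees of arrays.

```python
from dataclasses import dataclass


@dataclass
class RunningStatsSketch:
    """Hypothetical stand-in for RunningStatisticsState."""
    mean: dict
    std: dict
    count: float  # an int instead when int_counter=True
    summed_variance: dict


def init_state_sketch(nest, int_counter=False):
    """Build statistics with the same keys and leaf lengths as `nest`."""
    def filled(value):
        return {key: [value] * len(leaf) for key, leaf in nest.items()}

    return RunningStatsSketch(
        mean=filled(0.0),             # assumed initial mean
        std=filled(1.0),              # assumed initial standard deviation
        count=0 if int_counter else 0.0,
        summed_variance=filled(0.0),  # assumed initial accumulator
    )


template = {"position": [0.0, 0.0, 0.0], "velocity": [0.0, 0.0]}
state = init_state_sketch(template)
```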
- evorl.utils.running_statistics.normalize(batch: chex.Array, mean_std: evorl.utils.running_statistics.NestedMeanStd, eps: float = 1e-08, max_abs_value: float | None = None) → chex.Array
Normalizes data using running statistics.
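The signature hints at a standardization step followed by optional clipping. The sketch below shows one plausible reading for a single scalar feature: it assumes `eps` is added to the standard deviation to avoid division by zero and that `max_abs_value` clips the result symmetrically. Both are assumptions, since this page does not document how the two parameters are used. `normalize_sketch` is a hypothetical name.

```python
def normalize_sketch(values, mean, std, eps=1e-8, max_abs_value=None):
    """Standardize each value, then optionally clip to [-max_abs_value, max_abs_value]."""
    out = []
    for v in values:
        z = (v - mean) / (std + eps)  # assumed placement of eps
        if max_abs_value is not None:
            z = max(-max_abs_value, min(max_abs_value, z))
        out.append(z)
    return out


normalized = normalize_sketch(
    [8.0, 10.0, 30.0], mean=10.0, std=2.0, max_abs_value=5.0
)
```

With these inputs the outlier `30.0` would standardize to roughly 10 and is clipped to the bound of 5, while the other two values are unaffected by clipping.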
- evorl.utils.running_statistics.update(state: evorl.utils.running_statistics.RunningStatisticsState, batch: chex.ArrayTree, *, weights: chex.Array | None = None, std_min_value: float = 1e-06, std_max_value: float = 1000000.0, dp_axis_name: str | None = None, validate_shapes: bool = True) → evorl.utils.running_statistics.RunningStatisticsState
Updates the running statistics with the given batch of data.
Note: the data batch and the state elements (mean, etc.) must have the same structure.
Note: by default, int32 is used for counts and float32 for the accumulated variance. This results in an integer overflow after 2^31 data points and in degrading precision after 2^24 batch updates, or even earlier if variance updates have a large dynamic range. To improve precision, consider setting jax_enable_x64 to True; see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
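The two limits in the note can be reproduced without JAX. The sketch below uses only the Python standard library to emulate float32 rounding and int32 wraparound; it illustrates the underlying number formats, not evorl's internals, and the helper `to_float32` is hypothetical.

```python
import ctypes
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


# float32 has a 24-bit significand, so 2^24 + 1 is not representable:
big = to_float32(2.0 ** 24)
after_increment = to_float32(big + 1.0)  # rounds back to 2^24

# A 32-bit signed integer wraps around once it passes 2^31 - 1:
wrapped = ctypes.c_int32(2 ** 31).value  # becomes the most negative int32
```

In other words, a float32 accumulator stops registering unit increments at 2^24, and an int32 counter turns negative at 2^31.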
- Parameters:
state – The running statistics before the update.
batch – The data to be used to update the running statistics.
weights – Weights of the batch data. Should match the batch dimensions. Passing a weight of 2.0 should be equivalent to updating on the corresponding data point twice.
std_min_value – Minimum value for the standard deviation.
std_max_value – Maximum value for the standard deviation.
dp_axis_name – Name of the pmapped axis, if any.
validate_shapes – If true, the shapes of all leaves of the batch will be validated. Enabled by default. Doesn’t impact performance when jitted.
- Returns:
Updated running statistics.
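The fields `count`, `mean`, and `summed_variance`, together with the weight semantics above, are consistent with a weighted, batched Welford-style update. The sketch below shows that technique for a single scalar feature in plain Python. It is an assumption about the approach, not a copy of evorl's implementation, and `update_sketch` is a hypothetical name. Pmap aggregation (`dp_axis_name`) and shape validation are omitted.

```python
import math


def update_sketch(count, mean, summed_variance, batch, weights=None,
                  std_min_value=1e-6, std_max_value=1e6):
    """One weighted batch update of count, mean, summed variance, and std."""
    if weights is None:
        weights = [1.0] * len(batch)
    new_count = count + sum(weights)

    # Shift the mean by the weighted average deviation from the old mean.
    diff_to_old_mean = [x - mean for x in batch]
    new_mean = mean + sum(
        w * d for w, d in zip(weights, diff_to_old_mean)
    ) / new_count

    # Accumulate sum of w * (x - old_mean) * (x - new_mean).
    diff_to_new_mean = [x - new_mean for x in batch]
    new_summed_variance = summed_variance + sum(
        w * d_old * d_new
        for w, d_old, d_new in zip(weights, diff_to_old_mean, diff_to_new_mean)
    )

    # Population standard deviation, clipped to the configured range.
    std = math.sqrt(max(new_summed_variance, 0.0) / new_count)
    std = min(max(std, std_min_value), std_max_value)
    return new_count, new_mean, new_summed_variance, std


# Two consecutive batches should match the statistics of all six points.
count, mean, sv, std = update_sketch(0.0, 0.0, 0.0, [1.0, 2.0, 3.0, 4.0])
count, mean, sv, std = update_sketch(count, mean, sv, [5.0, 6.0])

# A weight of 2.0 should behave like seeing the data point twice.
weighted = update_sketch(0.0, 0.0, 0.0, [1.0, 3.0], weights=[2.0, 1.0])
repeated = update_sketch(0.0, 0.0, 0.0, [1.0, 1.0, 3.0])
```

Because only `count`, `mean`, and `summed_variance` are carried between calls, the statistics can be updated incrementally without storing past batches.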