Source code for evorl.utils.running_statistics

  1import chex
  2import jax
  3import jax.numpy as jnp
  4import jax.tree_util as jtu
  5
  6from evorl.types import PyTreeData
  7
  8from .jax_utils import tree_ones_like, tree_zeros_like
  9
 10"""Utility functions to compute running statistics.
 11
 12Modified from https://github.com/google/brax/blob/main/brax/training/acme/running_statistics.py
 13"""
 14
 15
class NestedMeanStd(PyTreeData):
    """Holds running statistics (mean, std) for possibly nested data."""

    # Running mean, stored as a (possibly nested) tree of arrays.
    mean: chex.ArrayTree
    # Running standard deviation, stored as a (possibly nested) tree of arrays.
    std: chex.ArrayTree
21 22
class RunningStatisticsState(NestedMeanStd):
    """Complete state needed to keep computing running statistics.

    Adds the accumulators to the `mean` / `std` fields inherited from
    `NestedMeanStd`.
    """

    # Accumulated sample count (or accumulated weight when `update` is
    # called with weights).
    count: chex.Array
    # Accumulated sum of squared deviations; `update` derives `std` from
    # `summed_variance / count`.
    summed_variance: chex.ArrayTree
28 29
def init_state(
    nest: chex.ArrayTree, int_counter: bool = False
) -> RunningStatisticsState:
    """Create the initial running statistics for the given nested structure.

    Arguments:
        nest: Nested structure used as the template for the statistics; it is
            forwarded to `tree_zeros_like` / `tree_ones_like`.
        int_counter: If True, the sample counter uses an integer dtype;
            otherwise it uses the floating-point dtype.

    Returns:
        A `RunningStatisticsState` with zero count, zero mean, zero summed
        variance and unit std.
    """
    # Use 64-bit dtypes only when JAX has x64 enabled.
    if jax.config.jax_enable_x64:
        int_dtype, float_dtype = jnp.int64, jnp.float64
    else:
        int_dtype, float_dtype = jnp.int32, jnp.float32
    count_dtype = int_dtype if int_counter else float_dtype

    zero_count = jnp.zeros((), dtype=count_dtype)
    zero_mean = tree_zeros_like(nest, dtype=float_dtype)
    zero_summed_variance = tree_zeros_like(nest, dtype=float_dtype)
    # Start std at one so that normalization works correctly in the
    # initial state.
    unit_std = tree_ones_like(nest, dtype=float_dtype)

    return RunningStatisticsState(
        count=zero_count,
        mean=zero_mean,
        summed_variance=zero_summed_variance,
        std=unit_std,
    )
def _validate_batch_shapes(
    batch: chex.Array, reference_sample: chex.Array, batch_dims: tuple[int, ...]
) -> None:
    """Check the shape of every batch leaf against the reference sample.

    Every leaf of `batch` must have the shape `batch_dims` followed by the
    shape of the matching leaf of `reference_sample`. This means the batch
    dimensions agree across all leaves and the non-batch dimensions agree
    with the reference.

    Arguments:
        batch: the nested batch of data to be verified.
        reference_sample: the nested array providing the non-batch dimensions.
        batch_dims: a tuple with the sizes of the leading batch dimensions.

    Returns:
        None. A mismatch is reported by `chex.assert_shape`.
    """

    def _check_leaf(reference_leaf: chex.Array, batch: chex.Array) -> None:
        expected_shape = batch_dims + reference_leaf.shape
        chex.assert_shape(
            batch,
            expected_shape,
            custom_message=f"{batch.shape} != {expected_shape}",
        )

    # tree_map is used only to walk both trees in lockstep; the mapped
    # result is discarded.
    jtu.tree_map(_check_leaf, reference_sample, batch)
def update(
    state: RunningStatisticsState,
    batch: chex.ArrayTree,
    *,
    weights: chex.Array | None = None,
    std_min_value: float = 1e-6,
    std_max_value: float = 1e6,
    dp_axis_name: str | None = None,
    validate_shapes: bool = True,
) -> RunningStatisticsState:
    """Updates the running statistics with the given batch of data.

    Note: data batch and state elements (mean, etc.) must have the same structure.

    Note: unless jax_enable_x64 is set, `init_state` creates 32-bit state
    (a float32 counter by default, int32 with `int_counter=True`). An int32
    counter overflows after 2^31 data points, and float32 accumulators
    degrade in precision after 2^24 batch updates or even earlier if variance
    updates have large dynamic range.
    To improve precision, consider setting jax_enable_x64 to True, see
    https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

    Arguments:
        state: The running statistics before the update.
        batch: The data to be used to update the running statistics.
        weights: Weights of the batch data. Should match the batch dimensions.
            Passing a weight of 2. should be equivalent to updating on the
            corresponding data point twice.
        std_min_value: Minimum value for the standard deviation.
        std_max_value: Maximum value for the standard deviation.
        dp_axis_name: Name of the pmapped axis, if any.
        validate_shapes: If true, the shapes of all leaves of the batch will be
            validated. Enabled by default. Doesn't impact performance when jitted.

    Returns:
        Updated running statistics.

    Raises:
        ValueError: If `validate_shapes` is set and `weights` does not have
            exactly the batch dimensions as its shape.
    """
    # We require exactly the same structure to avoid issues when flattened
    # batch and state have different order of elements.
    assert jtu.tree_structure(batch) == jtu.tree_structure(state.mean)
    batch_shape = jtu.tree_leaves(batch)[0].shape
    # We assume the batch dimensions always go first: whatever precedes the
    # per-sample shape (taken from the first leaf of `state.mean`) is batch.
    batch_dims = batch_shape[: len(batch_shape) - jtu.tree_leaves(state.mean)[0].ndim]
    batch_axis = tuple(range(len(batch_dims)))
    if weights is None:
        # Force an integer dtype: for an unbatched sample `batch_dims` is
        # empty and `jnp.array(())` would default to float, which silently
        # promotes an integer `state.count` to float.
        step_increment = jnp.prod(jnp.array(batch_dims, dtype=int))
    else:
        step_increment = jnp.sum(weights)
    if dp_axis_name is not None:
        step_increment = jax.lax.psum(step_increment, axis_name=dp_axis_name)
    count = state.count + step_increment

    # Validation is important. If the shapes don't match exactly, but are
    # compatible, arrays will be silently broadcasted resulting in incorrect
    # statistics.
    if validate_shapes:
        if weights is not None:
            if weights.shape != batch_dims:
                raise ValueError(f"{weights.shape} != {batch_dims}")
        _validate_batch_shapes(batch, state.mean, batch_dims)

    def _compute_node_statistics(
        mean: chex.Array, summed_variance: chex.Array, batch: chex.Array
    ) -> tuple[chex.Array, chex.Array]:
        assert isinstance(mean, chex.Array), type(mean)
        assert isinstance(summed_variance, chex.Array), type(summed_variance)
        # The mean and the sum of past variances are updated with Welford's
        # algorithm using batches (see https://stackoverflow.com/q/56402955).
        diff_to_old_mean = batch - mean
        if weights is not None:
            # Append singleton axes so the weights broadcast over the
            # non-batch dimensions of this leaf.
            expanded_weights = jnp.reshape(
                weights, list(weights.shape) + [1] * (batch.ndim - weights.ndim)
            )
            diff_to_old_mean = diff_to_old_mean * expanded_weights
        # `count` already includes this batch (and all devices when
        # `dp_axis_name` is set), so the per-device updates can be summed.
        mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count
        if dp_axis_name is not None:
            mean_update = jax.lax.psum(mean_update, axis_name=dp_axis_name)
        mean = mean + mean_update

        diff_to_new_mean = batch - mean
        variance_update = diff_to_old_mean * diff_to_new_mean
        variance_update = jnp.sum(variance_update, axis=batch_axis)
        if dp_axis_name is not None:
            variance_update = jax.lax.psum(variance_update, axis_name=dp_axis_name)
        summed_variance = summed_variance + variance_update
        return mean, summed_variance

    updated_stats = jtu.tree_map(
        _compute_node_statistics, state.mean, state.summed_variance, batch
    )
    # Extract `mean` and `summed_variance` from `updated_stats` nest.
    # `state.mean` is passed first so its leaves define where mapping stops,
    # leaving each `(mean, summed_variance)` tuple intact for indexing.
    mean = jtu.tree_map(lambda _, x: x[0], state.mean, updated_stats)
    summed_variance = jtu.tree_map(lambda _, x: x[1], state.mean, updated_stats)

    def compute_std(summed_variance: chex.Array) -> chex.Array:
        assert isinstance(summed_variance, chex.Array)
        # Summed variance can get negative due to rounding errors.
        summed_variance = jnp.maximum(summed_variance, 0)
        std = jnp.sqrt(summed_variance / count)
        return jnp.clip(std, std_min_value, std_max_value)

    std = jtu.tree_map(compute_std, summed_variance)

    return RunningStatisticsState(
        count=count, mean=mean, summed_variance=summed_variance, std=std
    )
181 182
def normalize(
    batch: chex.Array,
    mean_std: NestedMeanStd,
    eps: float = 1e-8,
    max_abs_value: float | None = None,
) -> chex.Array:
    """Normalizes values in a nested structure using the given mean/std.

    Leaves whose dtype is not inexact are returned unchanged.

    Arguments:
        batch: a nested structure containing the data to normalize.
        mean_std: mean and standard deviation used for normalization.
        eps: small constant added to the std before dividing.
        max_abs_value: if not None, normalized values are clipped to
            `[-max_abs_value, +max_abs_value]`.

    Returns:
        Nested structure with normalized values.
    """

    def _normalize_leaf(
        data: chex.Array, mean: chex.Array, std: chex.Array
    ) -> chex.Array:
        # Only inexact leaves are normalized; everything else passes through.
        if jnp.issubdtype(data.dtype, jnp.inexact):
            scaled = (data - mean) / (std + eps)
            if max_abs_value is None:
                return scaled
            return jnp.clip(scaled, -max_abs_value, +max_abs_value)
        return data

    return jtu.tree_map(_normalize_leaf, batch, mean_std.mean, mean_std.std)
204 205
def denormalize(batch: chex.Array, mean_std: NestedMeanStd) -> chex.Array:
    """Maps normalized values in a nested structure back via the given mean/std.

    Only values of inexact types are denormalized; other leaves are returned
    unchanged.
    See https://numpy.org/doc/stable/_images/dtype-hierarchy.png for Numpy type
    hierarchy.

    Args:
        batch: a nested structure containing batch of data.
        mean_std: mean and standard deviation used for denormalization.

    Returns:
        Nested structure with denormalized values.
    """

    def _denormalize_leaf(
        data: chex.Array, mean: chex.Array, std: chex.Array
    ) -> chex.Array:
        # Scale and shift inexact leaves; pass every other dtype through.
        is_inexact = jnp.issubdtype(data.dtype, jnp.inexact)
        return data * std + mean if is_inexact else data

    return jtu.tree_map(_denormalize_leaf, batch, mean_std.mean, mean_std.std)