1import chex
2import jax
3import jax.numpy as jnp
4import jax.tree_util as jtu
5
6from evorl.types import PyTreeData
7
8from .jax_utils import tree_ones_like, tree_zeros_like
9
10"""Utility functions to compute running statistics.
11
12Modified from https://github.com/google/brax/blob/main/brax/training/acme/running_statistics.py
13"""
14
15
[docs]
class NestedMeanStd(PyTreeData):
    """Per-leaf mean and standard deviation for a (possibly nested) data tree.

    Both fields are trees of arrays; the leaves of ``mean`` and ``std`` pair up
    with the leaves of the data they summarize.
    """

    # NOTE(review): declaration order presumably defines the constructor's
    # positional order / pytree leaf order (if PyTreeData is dataclass-based);
    # keep it unchanged.

    # Tree of per-leaf means.
    mean: chex.ArrayTree
    # Tree of per-leaf standard deviations.
    std: chex.ArrayTree
21
22
[docs]
class RunningStatisticsState(NestedMeanStd):
    """Complete running-statistics state: mean/std plus their accumulators.

    In addition to the inherited ``mean`` and ``std`` trees, it stores the
    (possibly weighted) number of samples seen so far and the running sum of
    squared deviations from which ``std`` is recomputed on every update.
    """

    # NOTE(review): field order after the inherited fields presumably matters
    # for construction / flattening if PyTreeData is dataclass-based.

    # Scalar total of (weighted) samples accumulated so far.
    count: chex.Array
    # Tree of Welford-style sums of squared deviations, one leaf per data leaf.
    summed_variance: chex.ArrayTree
28
29
[docs]
def init_state(
    nest: chex.ArrayTree, int_counter: bool = False
) -> RunningStatisticsState:
    """Create an empty running-statistics state matching the layout of ``nest``.

    Args:
        nest: Template tree passed to ``tree_zeros_like`` / ``tree_ones_like``
            (presumably only its structure and leaf shapes matter).
        int_counter: When True, ``count`` uses an integer dtype; otherwise it
            uses the same floating dtype as the statistics.

    Returns:
        A state with a zero scalar count, zero ``mean`` and ``summed_variance``,
        and ``std`` filled with ones. Dtypes are 64-bit when
        ``jax_enable_x64`` is enabled and 32-bit otherwise.
    """
    use_x64 = jax.config.jax_enable_x64
    float_dtype = jnp.float64 if use_x64 else jnp.float32
    if int_counter:
        count_dtype = jnp.int64 if use_x64 else jnp.int32
    else:
        count_dtype = float_dtype

    initial_count = jnp.zeros((), dtype=count_dtype)
    initial_mean = tree_zeros_like(nest, dtype=float_dtype)
    initial_summed_variance = tree_zeros_like(nest, dtype=float_dtype)
    # Unit std (rather than zero) keeps normalization well-defined before any
    # data has been folded in.
    initial_std = tree_ones_like(nest, dtype=float_dtype)

    return RunningStatisticsState(
        count=initial_count,
        mean=initial_mean,
        summed_variance=initial_summed_variance,
        std=initial_std,
    )
45
46
def _validate_batch_shapes(
    batch: chex.Array, reference_sample: chex.Array, batch_dims: tuple[int, ...]
) -> None:
    """Check that every batch leaf has shape ``batch_dims + reference.shape``.

    Requiring this exact shape for each leaf enforces two things at once: all
    leaves share the same leading batch dimensions, and each leaf's trailing
    (per-sample) dimensions equal those of the matching reference leaf.

    Arguments:
        batch: the nested batch of data to be verified.
        reference_sample: nested arrays supplying the expected per-sample shapes.
        batch_dims: sizes of the leading batch dimensions expected on every leaf.

    Returns:
        None. A mismatch is reported by ``chex.assert_shape``, which raises.
    """

    def _check_leaf(ref_leaf: chex.Array, batch: chex.Array) -> None:
        expected_shape = batch_dims + ref_leaf.shape
        chex.assert_shape(
            batch, expected_shape, custom_message=f"{batch.shape} != {expected_shape}"
        )

    # The reference tree drives the traversal; batch leaves are paired with it.
    jtu.tree_map(_check_leaf, reference_sample, batch)
73
74
[docs]
def update(
    state: RunningStatisticsState,
    batch: chex.ArrayTree,
    *,
    weights: chex.Array | None = None,
    std_min_value: float = 1e-6,
    std_max_value: float = 1e6,
    dp_axis_name: str | None = None,
    validate_shapes: bool = True,
) -> RunningStatisticsState:
    """Updates the running statistics with the given batch of data.

    Note: data batch and state elements (mean, etc.) must have the same structure.

    Note: ``count`` keeps the dtype chosen by ``init_state``: float32 by default,
    or int32 with ``int_counter=True`` (float64 / int64 when ``jax_enable_x64``
    is enabled). A float32 count stops counting exactly after 2^24 data points
    and an int32 count overflows after 2^31 data points; the float32 accumulated
    variance can lose precision even earlier if variance updates have a large
    dynamic range. When ``weights`` are given, their sum is added to ``count``
    with standard JAX type promotion, so float weights turn an integer count
    into a float one.
    To improve precision, consider setting jax_enable_x64 to True, see
    https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

    If the total count is still zero after this update (an empty batch, or
    all-zero weights before any data has been seen; non-negative weights
    assumed), mean, summed variance and std are returned unchanged instead of
    being overwritten with NaNs from a division by zero.

    Arguments:
        state: The running statistics before the update.
        batch: The data to be used to update the running statistics.
        weights: Weights of the batch data. Should match the batch dimensions.
            Passing a weight of 2. should be equivalent to updating on the
            corresponding data point twice.
        std_min_value: Minimum value for the standard deviation.
        std_max_value: Maximum value for the standard deviation.
        dp_axis_name: Name of the pmapped axis, if any.
        validate_shapes: If true, the shapes of all leaves of the batch will be
            validated. Enabled by default. Doesn't impact performance when jitted.

    Returns:
        Updated running statistics.

    Raises:
        ValueError: if ``validate_shapes`` is set and ``weights.shape`` differs
            from the batch dimensions.
    """
    # We require exactly the same structure to avoid issues when flattened
    # batch and state have different order of elements.
    assert jtu.tree_structure(batch) == jtu.tree_structure(state.mean)
    batch_shape = jtu.tree_leaves(batch)[0].shape
    # We assume the batch dimensions always go first.
    batch_dims = batch_shape[: len(batch_shape) - jtu.tree_leaves(state.mean)[0].ndim]
    batch_axis = tuple(range(len(batch_dims)))

    # Validation is important. If the shapes don't match exactly, but are
    # compatible, arrays will be silently broadcasted resulting in incorrect
    # statistics.
    if validate_shapes:
        if weights is not None and weights.shape != batch_dims:
            raise ValueError(f"{weights.shape} != {batch_dims}")
        _validate_batch_shapes(batch, state.mean, batch_dims)

    if weights is None:
        # Compute the sample count as an exact integer product, then cast it to
        # the counter's dtype. Without the cast, jnp.array(()) (no batch dims)
        # defaults to a float dtype and would promote an integer counter.
        step_increment = jnp.prod(jnp.array(batch_dims, dtype=int)).astype(
            state.count.dtype
        )
    else:
        step_increment = jnp.sum(weights)
    if dp_axis_name is not None:
        step_increment = jax.lax.psum(step_increment, axis_name=dp_axis_name)
    count = state.count + step_increment

    # With a zero total count every (weighted) sum below is zero, so dividing
    # by 1 instead of 0 leaves mean/summed_variance unchanged instead of NaN.
    has_data = count > 0
    safe_count = jnp.where(has_data, count, 1)

    def _compute_node_statistics(
        mean: chex.Array, summed_variance: chex.Array, batch: chex.Array
    ) -> tuple[chex.Array, chex.Array]:
        assert isinstance(mean, chex.Array), type(mean)
        assert isinstance(summed_variance, chex.Array), type(summed_variance)
        # The mean and the sum of past variances are updated with Welford's
        # algorithm using batches (see https://stackoverflow.com/q/56402955).
        diff_to_old_mean = batch - mean
        if weights is not None:
            # Append trailing singleton axes so the weights broadcast over the
            # per-sample (non-batch) dimensions.
            expanded_weights = jnp.reshape(
                weights, tuple(weights.shape) + (1,) * (batch.ndim - weights.ndim)
            )
            diff_to_old_mean = diff_to_old_mean * expanded_weights
        mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / safe_count
        if dp_axis_name is not None:
            mean_update = jax.lax.psum(mean_update, axis_name=dp_axis_name)
        mean = mean + mean_update

        diff_to_new_mean = batch - mean
        variance_update = diff_to_old_mean * diff_to_new_mean
        variance_update = jnp.sum(variance_update, axis=batch_axis)
        if dp_axis_name is not None:
            variance_update = jax.lax.psum(variance_update, axis_name=dp_axis_name)
        summed_variance = summed_variance + variance_update
        return mean, summed_variance

    updated_stats = jtu.tree_map(
        _compute_node_statistics, state.mean, state.summed_variance, batch
    )
    # Extract `mean` and `summed_variance` from `updated_stats` nest. Mapping
    # over `state.mean` makes each (mean, summed_variance) tuple act as a leaf.
    mean = jtu.tree_map(lambda _, x: x[0], state.mean, updated_stats)
    summed_variance = jtu.tree_map(lambda _, x: x[1], state.mean, updated_stats)

    def compute_std(summed_variance: chex.Array, prev_std: chex.Array) -> chex.Array:
        assert isinstance(summed_variance, chex.Array)
        # Summed variance can get negative due to rounding errors.
        summed_variance = jnp.maximum(summed_variance, 0)
        new_std = jnp.sqrt(summed_variance / safe_count)
        new_std = jnp.clip(new_std, std_min_value, std_max_value)
        # Without any data the recomputed std is meaningless; keep the old one.
        return jnp.where(has_data, new_std, prev_std)

    std = jtu.tree_map(compute_std, summed_variance, state.std)

    return RunningStatisticsState(
        count=count, mean=mean, summed_variance=summed_variance, std=std
    )
181
182
[docs]
def normalize(
    batch: chex.Array,
    mean_std: NestedMeanStd,
    eps: float = 1e-8,
    max_abs_value: float | None = None,
) -> chex.Array:
    """Standardize every inexact leaf of ``batch`` with the given statistics.

    Each floating/complex leaf becomes ``(x - mean) / (std + eps)``, optionally
    clipped to ``[-max_abs_value, max_abs_value]``. Leaves with non-inexact
    dtypes (e.g. integers, booleans) are passed through untouched.

    Args:
        batch: (possibly nested) data to normalize; its structure is matched
            against ``mean_std.mean`` and ``mean_std.std`` by ``tree_map``.
        mean_std: statistics used for normalization.
        eps: added to ``std`` to avoid division by zero.
        max_abs_value: if given, symmetric bound applied after normalization.

    Returns:
        A tree with the same structure as ``batch``.
    """

    def _normalize_one(x: chex.Array, mu: chex.Array, sigma: chex.Array) -> chex.Array:
        if jnp.issubdtype(x.dtype, jnp.inexact):
            x = (x - mu) / (sigma + eps)
            if max_abs_value is not None:
                x = jnp.clip(x, -max_abs_value, max_abs_value)
        return x

    return jtu.tree_map(_normalize_one, batch, mean_std.mean, mean_std.std)
204
205
[docs]
def denormalize(batch: chex.Array, mean_std: NestedMeanStd) -> chex.Array:
    """Invert normalization on every inexact leaf: ``x * std + mean``.

    Leaves whose dtype is not a subtype of ``jnp.inexact`` (floating or complex)
    are returned as-is. See
    https://numpy.org/doc/stable/_images/dtype-hierarchy.png for the Numpy type
    hierarchy.

    Args:
        batch: a nested structure containing batch of data.
        mean_std: mean and standard deviation used for denormalization.

    Returns:
        Nested structure with denormalized values.
    """

    def _denormalize_one(x: chex.Array, mu: chex.Array, sigma: chex.Array) -> chex.Array:
        is_inexact = jnp.issubdtype(x.dtype, jnp.inexact)
        return x * sigma + mu if is_inexact else x

    return jtu.tree_map(_denormalize_one, batch, mean_std.mean, mean_std.std)