Source code for evorl.distribution
1from typing import Any
2
3import jax
4import jax.numpy as jnp
5from tensorflow_probability.substrates import jax as tfp
6
7tfd = tfp.distributions # note: tfp use lazy init.
8
9
[docs]
def get_categorical_dist(logits: jax.Array):
    """Build a categorical distribution from logits.

    Args:
        logits: Array forwarded to ``tfd.Categorical`` as its ``logits``
            argument.

    Returns:
        The ``tfd.Categorical`` instance constructed from ``logits``.
    """
    categorical = tfd.Categorical(logits=logits)
    return categorical
13
14
[docs]
def get_tanh_norm_dist(loc: jax.Array, scale: jax.Array, min_scale: float = 1e-3):
    """Build a tanh-squashed normal distribution.

    Args:
        loc: Location parameter of the underlying normal distribution.
        scale: Raw scale parameter. It is passed through softplus and then
            offset by ``min_scale`` before being used as the normal's scale.
        min_scale: Constant added after the softplus so the effective scale
            stays strictly above zero.

    Returns:
        A ``tfd.Independent`` wrapping the tanh-transformed normal, with the
        rightmost batch dimension reinterpreted as an event dimension.
    """
    # Softplus keeps the scale positive; min_scale bounds it away from zero.
    positive_scale = jax.nn.softplus(scale) + min_scale
    base_normal = tfd.Normal(loc=loc, scale=positive_scale)
    squashed_normal = TanhTransformedDistribution(base_normal)
    return tfd.Independent(squashed_normal, reinterpreted_batch_ndims=1)
22
23
24# class TanhNormal(distrax.Transformed):
25# def __init__(self, loc, scale):
26# super().__init__(
27# distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale),
28# distrax.Block(distrax.Tanh(), ndims=1)
29# )
30
31# def mode(self):
32# loc = self.distribution.mode()
33# return self.bijector.forward(loc)
34
35# def entropy(self, input_hint=None):
36# """
37# No analytical form. use sample to estimate.
38# input_hint: an example sample from the base distribution
39# eg: self.distribution.sample(seed=jax.random.PRNGKey(42))
40# """
41
42# return self.distribution.entropy() + self.bijector.forward_log_det_jacobian(
43# input_hint)
44
45
[docs]
def get_truncated_norm_dist(loc, scale, low, high):
    """Get a truncated normal distribution.

    Args:
        loc: Forwarded to ``tfd.TruncatedNormal`` as ``loc``.
        scale: Forwarded to ``tfd.TruncatedNormal`` as ``scale``.
        low: Forwarded to ``tfd.TruncatedNormal`` as ``low``.
        high: Forwarded to ``tfd.TruncatedNormal`` as ``high``.

    Returns:
        The ``tfd.TruncatedNormal`` instance built from the arguments.
    """
    return tfd.TruncatedNormal(loc=loc, scale=scale, low=low, high=high)


def get_trancated_norm_dist(loc, scale, low, high):
    """Get a truncated normal distribution.

    This is the original (misspelled) public name, kept so existing callers
    continue to work. It forwards to ``get_truncated_norm_dist`` unchanged.
    """
    return get_truncated_norm_dist(loc, scale, low, high)
49
50
[docs]
class TanhTransformedDistribution(tfd.TransformedDistribution):
    """Distribution followed by tanh (adapted from acme).

    The log-probability is clipped near the boundaries of the tanh image:
    events at or beyond ``+/-threshold`` are assigned the log of the average
    density of the corresponding tail, which keeps ``log_prob`` finite there.
    """

    def __init__(self, distribution, threshold=0.999, validate_args=False):
        """Initialize the distribution.

        Args:
            distribution: The distribution to transform.
            threshold: Clipping value of the action when computing the logprob.
            validate_args: Passed to super class.
        """
        super().__init__(
            distribution=distribution,
            bijector=tfp.bijectors.Tanh(),
            validate_args=validate_args,
        )
        # Computes the log of the average probability distribution outside the
        # clipping range, i.e. on the interval [-inf, -atanh(threshold)] for
        # log_prob_left and [atanh(threshold), inf] for log_prob_right.
        self._threshold = threshold
        inverse_threshold = self.bijector.inverse(threshold)
        # average(pdf) = p / epsilon
        # So log(average(pdf)) = log(p) - log(epsilon)
        log_epsilon = jnp.log(1.0 - threshold)
        # Those 2 values are differentiable w.r.t. model parameters, such that
        # the gradient is defined everywhere.
        self._log_prob_left = (
            self.distribution.log_cdf(-inverse_threshold) - log_epsilon
        )
        self._log_prob_right = (
            self.distribution.log_survival_function(inverse_threshold) - log_epsilon
        )

    def log_prob(self, event):
        """Return the log-probability of ``event`` with boundary clipping.

        Args:
            event: Value(s) in the tanh image. Values are clipped to
                ``[-threshold, threshold]`` before evaluation.

        Returns:
            ``_log_prob_left`` where the clipped event is at ``-threshold``,
            ``_log_prob_right`` where it is at ``threshold``, and the
            transformed distribution's regular log-probability elsewhere.
        """
        # Without this clip there would be NaNs in the inner jnp.where branch,
        # which would propagate even though that branch is not selected.
        event = jnp.clip(event, -self._threshold, self._threshold)
        # The inverse image of {threshold} is the interval [atanh(threshold), inf]
        # which has a probability of "log_prob_right" under the given distribution.
        return jnp.where(
            event <= -self._threshold,
            self._log_prob_left,
            jnp.where(
                event >= self._threshold, self._log_prob_right, super().log_prob(event)
            ),
        )

    def mode(self):
        """Return tanh applied to the mode of the base distribution.

        Note: this is the image of the base distribution's mode under tanh,
        which is what is typically wanted as a deterministic output. It is not,
        in general, the argmax of the transformed density.
        """
        return self.bijector.forward(self.distribution.mode())

    def entropy(self, seed=None):
        """Return a single-sample estimate of the entropy.

        Args:
            seed: Forwarded to the base distribution's ``sample``.
                NOTE(review): the JAX substrate presumably requires an
                explicit PRNG key here rather than ``None`` - confirm.

        Returns:
            Base distribution entropy plus the tanh forward log-det-Jacobian
            evaluated at one sample from the base distribution.
        """
        # We return an estimation using a single sample of the log_det_jacobian.
        # We can still do some backpropagation with this estimate.
        return self.distribution.entropy() + self.bijector.forward_log_det_jacobian(
            self.distribution.sample(seed=seed), event_ndims=0
        )

    @classmethod
    def _parameter_properties(cls, dtype: Any | None, num_classes=None):
        """Return the parent's parameter properties without ``bijector``.

        The bijector is fixed to tanh by ``__init__`` and is not a constructor
        argument of this class, so its entry is removed.
        """
        td_properties = super()._parameter_properties(dtype, num_classes=num_classes)
        del td_properties["bijector"]
        return td_properties