Source code for evorl.distribution

  1from typing import Any
  2
  3import jax
  4import jax.numpy as jnp
  5from tensorflow_probability.substrates import jax as tfp
  6
  7tfd = tfp.distributions  # note: tfp use lazy init.
  8
  9
def get_categorical_dist(logits: jax.Array):
    """Build a categorical distribution from unnormalized logits.

    Args:
        logits: Unnormalized log-probabilities passed straight through to
            ``tfd.Categorical``.

    Returns:
        A ``tfd.Categorical`` parameterized by ``logits``.
    """
    categorical = tfd.Categorical(logits=logits)
    return categorical
13 14
def get_tanh_norm_dist(
    loc: jax.Array, scale: jax.Array, min_scale: float = 1e-3
):
    """Build a tanh-squashed normal distribution with an event axis of rank 1.

    The raw ``scale`` is first made strictly positive via
    ``softplus(scale) + min_scale``. The resulting ``tfd.Normal`` is pushed
    through ``TanhTransformedDistribution``. Finally, one trailing batch axis
    is reinterpreted as the event axis with ``tfd.Independent``.

    Args:
        loc: Mean of the base normal (pre-tanh).
        scale: Unconstrained scale parameter; softplus is applied here.
        min_scale: Floor added after softplus so the std never reaches 0.

    Returns:
        A ``tfd.Independent`` wrapping the tanh-transformed normal.
    """
    positive_scale = jax.nn.softplus(scale) + min_scale
    base_normal = tfd.Normal(loc=loc, scale=positive_scale)
    squashed = TanhTransformedDistribution(base_normal)
    # Reinterpret the last batch axis as the event axis so log_prob sums
    # over it instead of returning a per-dimension value.
    return tfd.Independent(squashed, reinterpreted_batch_ndims=1)
22 23 24# class TanhNormal(distrax.Transformed): 25# def __init__(self, loc, scale): 26# super().__init__( 27# distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale), 28# distrax.Block(distrax.Tanh(), ndims=1) 29# ) 30 31# def mode(self): 32# loc = self.distribution.mode() 33# return self.bijector.forward(loc) 34 35# def entropy(self, input_hint=None): 36# """ 37# No analytical form. use sample to estimate. 38# input_hint: an example sample from the base distribution 39# eg: self.distribution.sample(seed=jax.random.PRNGKey(42)) 40# """ 41 42# return self.distribution.entropy() + self.bijector.forward_log_det_jacobian( 43# input_hint) 44 45
def get_truncated_norm_dist(loc, scale, low, high):
    """Get a truncated normal distribution.

    Thin wrapper around ``tfd.TruncatedNormal``. All arguments are forwarded
    unchanged.

    Args:
        loc: Mean of the underlying (untruncated) normal.
        scale: Standard deviation of the underlying normal.
        low: Lower truncation bound.
        high: Upper truncation bound.

    Returns:
        A ``tfd.TruncatedNormal`` with support restricted to ``[low, high]``.
    """
    return tfd.TruncatedNormal(loc=loc, scale=scale, low=low, high=high)


# Backward-compatible alias: the original public name misspelled "truncated".
# Existing callers of ``get_trancated_norm_dist`` keep working unchanged.
get_trancated_norm_dist = get_truncated_norm_dist
49 50
class TanhTransformedDistribution(tfd.TransformedDistribution):
    """A base distribution squashed through tanh (adapted from Acme).

    ``log_prob`` is made numerically safe near the +/-1 boundaries. Any event
    whose magnitude reaches ``threshold`` is assigned the average log-density
    of the base distribution's tail beyond ``atanh(threshold)``.
    """

    def __init__(self, distribution, threshold=0.999, validate_args=False):
        """Wrap ``distribution`` with a tanh bijector.

        Args:
            distribution: The base distribution to transform.
            threshold: Clipping magnitude applied to events in ``log_prob``.
            validate_args: Forwarded to ``tfd.TransformedDistribution``.
        """
        super().__init__(
            distribution=distribution,
            bijector=tfp.bijectors.Tanh(),
            validate_args=validate_args,
        )
        self._threshold = threshold
        # Pre-tanh point corresponding to the clipping threshold.
        clip_point = self.bijector.inverse(threshold)
        # Each tail's mass p is spread over an action-space width of
        # (1 - threshold), so its average density is p / (1 - threshold),
        # i.e. log(p) - log(1 - threshold) in log space.
        log_width = jnp.log(1.0 - threshold)
        # Both tail values stay functions of the base distribution, so they
        # remain differentiable w.r.t. its parameters everywhere.
        self._log_prob_left = self.distribution.log_cdf(-clip_point) - log_width
        self._log_prob_right = (
            self.distribution.log_survival_function(clip_point) - log_width
        )

    def log_prob(self, event):
        """Log-density of ``event``, with constant tail values at the clip edges."""
        # Clip first: without it there would be NaNs in the interior branch of
        # the where below, which causes problems.
        clipped = jnp.clip(event, -self._threshold, self._threshold)
        at_lower_edge = clipped <= -self._threshold
        at_upper_edge = clipped >= self._threshold
        interior = super().log_prob(clipped)
        # The inverse image of the upper edge is [atanh(threshold), inf), whose
        # mass under the base distribution gives ``_log_prob_right``; the
        # lower edge is symmetric.
        upper_or_interior = jnp.where(at_upper_edge, self._log_prob_right, interior)
        return jnp.where(at_lower_edge, self._log_prob_left, upper_or_interior)

    def mode(self):
        """Return tanh applied to the base distribution's mode."""
        base_mode = self.distribution.mode()
        return self.bijector.forward(base_mode)

    def entropy(self, seed=None):
        """Single-sample estimate of the entropy of the squashed distribution.

        Computed as the base entropy plus the tanh forward log-det-Jacobian at
        one sample drawn from the base distribution with ``seed``. The estimate
        remains differentiable.
        """
        base_entropy = self.distribution.entropy()
        draw = self.distribution.sample(seed=seed)
        return base_entropy + self.bijector.forward_log_det_jacobian(
            draw, event_ndims=0
        )

    @classmethod
    def _parameter_properties(cls, dtype: Any | None, num_classes=None):
        # Start from TransformedDistribution's parameter properties and drop
        # "bijector": this class always builds its own Tanh bijector.
        props = super()._parameter_properties(dtype, num_classes=num_classes)
        del props["bijector"]
        return props