Source code for evorl.utils.ec_utils

 1import jax
 2
 3from jax.flatten_util import ravel_pytree
 4from jax.tree_util import tree_leaves
 5
 6__all__ = ["ParamVectorSpec"]
 7
 8
class ParamVectorSpec:
    """Save the structure of the parameters.

    Provide methods to convert between the original tree-like parameters and
    the flattened parameter vector. Both conversions also accept inputs with
    extra leading batch axes (e.g. a population of parameters); every leading
    batch axis is mapped over with ``jax.vmap``.
    """

    def __init__(self, params):
        """Initialize the ParamVectorSpec.

        Args:
            params: Provide the structure of the parameters. It should be a
                single instance of model parameters instead of a batch of
                parameters.
        """
        # ndim of the first leaf of an *unbatched* tree. `to_vector` compares
        # an incoming tree's first leaf against it to count batch axes.
        self._ndim = tree_leaves(params)[0].ndim
        flat, self.to_tree_fn = ravel_pytree(params)
        # Length of the flattened vector for one (unbatched) params tree.
        self.vec_size = flat.shape[0]
        self.to_vec_fn = lambda x: ravel_pytree(x)[0]

    @staticmethod
    def _batched(fn, batch_ndim):
        """Wrap ``fn`` in ``batch_ndim`` nested ``jax.vmap`` calls (none if <= 0)."""
        for _ in range(batch_ndim):
            fn = jax.vmap(fn)
        return fn

    def to_vector(self, x) -> jax.Array:
        """Convert the original params to flatten params.

        See `jax.flatten_util.ravel_pytree`.

        Args:
            x: The original params. The leaves may carry extra leading batch
                axes compared to the params given to ``__init__``.

        Returns:
            Flatten params of shape ``(*batch_shape, vec_size)``.
        """
        # Extra axes on the first leaf, relative to the unbatched structure,
        # are batch axes.
        batch_ndim = tree_leaves(x)[0].ndim - self._ndim
        return self._batched(self.to_vec_fn, batch_ndim)(x)

    def to_tree(self, x):
        """Convert the flatten params to the original params.

        Args:
            x: The flatten params, of shape ``(vec_size,)`` for a single
                instance or ``(*batch_shape, vec_size)`` for batched params.

        Returns:
            A pytree with the same structure as the params given to
            ``__init__``. Each leaf carries ``batch_shape`` as its leading axes.
        """
        # An unbatched flat vector is always 1-D, whatever the ndim of the
        # original tree's leaves, so every axis beyond the last is a batch axis.
        # (Comparing against `self._ndim` here mis-counted batch axes whenever
        # the tree's first leaf was not 1-D.)
        batch_ndim = tree_leaves(x)[0].ndim - 1
        return self._batched(self.to_tree_fn, batch_ndim)(x)