Source code for evorl.utils.ec_utils
from typing import Any

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_leaves
5
6__all__ = ["ParamVectorSpec"]
7
8
[docs]
class ParamVectorSpec:
    """Save the structure of the parameters.

    Provide methods to convert between the original tree-like parameters and
    the flattened parameter vector. Both conversions also accept inputs with
    extra leading batch dimensions, which are handled by nested ``jax.vmap``.
    """

    def __init__(self, params):
        """Initialize the ParamVectorSpec.

        Args:
            params: Provide the structure of the parameters. It should be a
                single instance of model parameters instead of a batch of
                parameters.

        Raises:
            ValueError: If ``params`` contains no leaves.
        """
        leaves = tree_leaves(params)
        if not leaves:
            raise ValueError("params must contain at least one array leaf")
        # Rank of the first leaf of an *unbatched* params instance. In
        # `to_vector`, any extra leading dims on that leaf are batch dims.
        self._ndim = jnp.ndim(leaves[0])
        flat, self.to_tree_fn = ravel_pytree(params)
        self.vec_size = flat.shape[0]
        # Note: this flattens whatever tree it is given; it does not check
        # that `x` has the same structure as `params`.
        self.to_vec_fn = lambda x: ravel_pytree(x)[0]

    @staticmethod
    def _vmap_n(fn, batch_ndim):
        """Wrap ``fn`` in ``batch_ndim`` nested ``jax.vmap`` calls (none if <= 0)."""
        for _ in range(batch_ndim):
            fn = jax.vmap(fn)
        return fn

    def to_vector(self, x) -> jax.Array:
        """Convert the original params to flatten params.

        see `jax.flatten_util.ravel_pytree`

        Args:
            x: Params with the same structure as the ones passed to
                ``__init__``, optionally with extra leading batch dimensions
                shared by all leaves.

        Returns:
            Flatten params of shape ``(*batch_dims, vec_size)``.
        """
        # Batch dims are inferred from the first leaf's rank compared with the
        # rank recorded for a single (unbatched) params instance.
        batch_ndim = jnp.ndim(tree_leaves(x)[0]) - self._ndim
        return self._vmap_n(self.to_vec_fn, batch_ndim)(x)

    def to_tree(self, x) -> Any:
        """Convert the flatten params to the original params.

        Args:
            x: The flatten params, of shape ``(vec_size,)`` or
                ``(*batch_dims, vec_size)``.

        Returns:
            The original params (a pytree). Each leaf has shape
            ``(*batch_dims, *leaf_shape)``.

        Raises:
            ValueError: If ``x`` is a 0-d array.
        """
        # A single flat vector is always 1-D, whatever the ranks of the
        # original leaves are, so every axis before the last is a batch axis.
        batch_ndim = jnp.ndim(x) - 1
        if batch_ndim < 0:
            raise ValueError(
                "to_tree expects an array of shape (..., vec_size), got a 0-d array"
            )
        return self._vmap_n(self.to_tree_fn, batch_ndim)(x)