Source code for evorl.envs.space
1import chex
2import jax
3import jax.numpy as jnp
4import jax.tree_util as jtu
5
6from evorl.types import PyTreeData
7from evorl.utils.jax_utils import rng_split_like_tree
8
9
[docs]
class Space(PyTreeData):
    """Abstract description of a set of values, in the spirit of `gym.Space`.

    This base class only declares the interface: every member below raises
    `NotImplementedError` until a subclass overrides it.
    """

    @property
    def shape(self) -> chex.Shape:
        """Shape of a single element of the space."""
        raise NotImplementedError()

    def sample(self, key: chex.PRNGKey) -> chex.Array:
        """Draw one random element of the space.

        Args:
            key: PRNG key driving the draw.

        Returns:
            A sample from the space.
        """
        raise NotImplementedError()

    def contains(self, x: chex.Array) -> bool:
        """Check whether `x` belongs to the space.

        Args:
            x: The candidate value.

        Returns:
            A boolean value telling whether `x` is in the space.
        """
        raise NotImplementedError()
33
34
[docs]
class Box(Space):
    """Continuous, axis-aligned box in R^n.

    Attributes:
        low: Lower bounds of the box; its shape and dtype are reused for
            the samples and for `shape`.
        high: Upper bounds of the box.
    """

    low: chex.Array
    high: chex.Array

    @property
    def shape(self) -> chex.Shape:
        """Shape of one element, taken from `low`."""
        return self.low.shape

    def sample(self, key: chex.PRNGKey) -> chex.Array:
        """Draw a uniform random point between `low` and `high`.

        Args:
            key: PRNG key passed to `jax.random.uniform`.

        Returns:
            An array with the shape and dtype of `low`.
        """
        lower, upper = self.low, self.high
        return jax.random.uniform(
            key,
            shape=lower.shape,
            dtype=lower.dtype,
            minval=lower,
            maxval=upper,
        )

    def contains(self, x: chex.Array) -> chex.Array:
        """Check that every entry of `x` lies within `[low, high]`.

        Args:
            x: The candidate value, compared element-wise with the bounds.

        Returns:
            A single boolean array value: true only if all entries satisfy
            both bounds (bounds inclusive).
        """
        above_low = jnp.all(x >= self.low)
        below_high = jnp.all(x <= self.high)
        return jnp.logical_and(above_low, below_high)
61
62
[docs]
class Discrete(Space):
    """Finite space of the integers {0, 1, ..., n-1}.

    Attributes:
        n: The number of discrete values.
    """

    n: int

    @property
    def shape(self) -> chex.Shape:
        """Shape of one element: always the empty tuple (a scalar)."""
        return ()

    def sample(self, key: chex.PRNGKey) -> chex.Array:
        """Draw a random scalar integer with `0 <= value < n`.

        Args:
            key: PRNG key passed to `jax.random.randint`.

        Returns:
            A scalar integer array.
        """
        # Positional arguments are (key, shape, minval, maxval).
        return jax.random.randint(key, (), 0, self.n)

    def contains(self, x: chex.Array) -> chex.Array:
        """Check element-wise whether `0 <= x < n`.

        Args:
            x: The candidate value(s).

        Returns:
            The element-wise conjunction of the two bound checks; no
            reduction is applied.
        """
        non_negative = x >= 0
        below_n = x < self.n
        return jnp.logical_and(non_negative, below_n)
81
82
[docs]
class SpaceContainer(Space):
    """Space made of a pytree of other spaces.

    Any `Space` instance found in `spaces` (including a nested
    `SpaceContainer`) is treated as a single leaf and handled through its own
    `sample` / `shape` / `contains`.

    Attributes:
        spaces: A pytree whose leaves are spaces.
    """

    spaces: chex.ArrayTree

    def sample(self, key: chex.PRNGKey) -> chex.ArrayTree:
        """Sample every sub-space and return the results in the same layout.

        Args:
            key: PRNG key handed to `rng_split_like_tree` to obtain a tree of
                keys arranged like `spaces`.

        Returns:
            A pytree with one sample per sub-space.
        """

        def _is_space(node):
            return isinstance(node, Space)

        def _draw(space, subkey):
            return space.sample(subkey)

        subkeys = rng_split_like_tree(key, self.spaces, is_leaf=_is_space)
        return jtu.tree_map(_draw, self.spaces, subkeys, is_leaf=_is_space)

    @property
    def shape(self) -> chex.ArrayTree:
        """Pytree of the sub-spaces' shapes, arranged like `spaces`."""

        def _is_space(node):
            return isinstance(node, Space)

        def _shape_of(space):
            return space.shape

        return jtu.tree_map(_shape_of, self.spaces, is_leaf=_is_space)

    def contains(self, data: chex.ArrayTree) -> chex.ArrayTree:
        """Check that each part of `data` lies in the matching sub-space.

        Args:
            data: A pytree laid out like `spaces`; the entry at each space's
                position is passed to that space's `contains`.

        Returns:
            The result of `jtu.tree_all` over the per-space checks.
        """

        def _is_space(node):
            return isinstance(node, Space)

        def _member(space, value):
            return space.contains(value)

        memberships = jtu.tree_map(_member, self.spaces, data, is_leaf=_is_space)
        # NOTE(review): `tree_all` appears to combine the leaves using Python
        # truthiness, so the declared `chex.ArrayTree` return type looks
        # inaccurate and this may not be traceable under `jax.jit` — confirm.
        return jtu.tree_all(memberships)
120
121
[docs]
def is_leaf_space(space):
    """Tell whether `space` is a `Space` that is not a `SpaceContainer`.

    Args:
        space: Any object.

    Returns:
        True for `Space` instances other than `SpaceContainer` instances,
        False for everything else.
    """
    # Containers are spaces too, so rule them out first.
    if isinstance(space, SpaceContainer):
        return False
    return isinstance(space, Space)