# Source code for evorl.envs.space

import chex
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from evorl.types import PyTreeData
from evorl.utils.jax_utils import rng_split_like_tree


class Space(PyTreeData):
    """Base class for spaces, analogous to `gym.Space`.

    Subclasses override `shape`, `sample` and `contains`; the base versions
    only raise `NotImplementedError`.
    """

    @property
    def shape(self) -> chex.Shape:
        """Get the shape of the space.

        Raises:
            NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError

    def sample(self, key: chex.PRNGKey) -> chex.Array:
        """Randomly sample an element of this space.

        Args:
            key: JAX PRNG key used for sampling.

        Returns:
            A sample from the space.

        Raises:
            NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError

    def contains(self, x: chex.Array) -> chex.Array:
        """Determine whether the input lies in the space.

        Args:
            x: The candidate value.

        Returns:
            A boolean value (typically a scalar JAX array, so the check can
            be traced) indicating whether `x` is in the space.

        Raises:
            NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError


class Box(Space):
    """A continuous, axis-aligned box in R^n.

    Attributes:
        low: Element-wise lower bounds; its shape and dtype define the
            shape and dtype of samples.
        high: Element-wise upper bounds (presumably the same shape as
            `low` — TODO confirm at construction sites).
    """

    low: chex.Array
    high: chex.Array

    def sample(self, key: chex.PRNGKey) -> chex.Array:
        """Draw a uniform sample between `low` and `high`.

        Args:
            key: JAX PRNG key.

        Returns:
            An array with the same shape and dtype as `low`.
        """
        template = self.low
        upper = self.high
        return jax.random.uniform(
            key,
            shape=template.shape,
            dtype=template.dtype,
            minval=template,
            maxval=upper,
        )

    @property
    def shape(self) -> chex.Shape:
        """Shape of the box, taken from `low`."""
        lower_bound = self.low
        return lower_bound.shape

    def contains(self, x: chex.Array) -> chex.Array:
        """Return whether every element of `x` lies within [low, high].

        Args:
            x: The candidate value, compared element-wise with the bounds.

        Returns:
            A scalar boolean JAX array.
        """
        above_low = jnp.all(x >= self.low)
        below_high = jnp.all(x <= self.high)
        return jnp.logical_and(above_low, below_high)


class Discrete(Space):
    """Discrete space over {0, 1, ..., n-1}.

    Attributes:
        n: The number of discrete values.
    """

    n: int

    def sample(self, key: chex.PRNGKey) -> chex.Array:
        """Sample an integer uniformly from [0, n).

        Args:
            key: JAX PRNG key.

        Returns:
            A scalar integer JAX array.
        """
        return jax.random.randint(key, shape=(), minval=0, maxval=self.n)

    @property
    def shape(self) -> chex.Shape:
        """Samples are scalars, so the shape is always `()`."""
        return ()

    def contains(self, x: chex.Array) -> chex.Array:
        """Return whether `x` lies in the range [0, n).

        The element-wise range checks are reduced with `jnp.all`, so a
        single boolean is returned for any shape of `x`, matching the
        scalar result of `Box.contains` and the base-class contract.

        NOTE: only the range is checked; non-integer values inside the
        range (e.g. 0.5) are not rejected.

        Args:
            x: The candidate value.

        Returns:
            A scalar boolean JAX array.
        """
        in_range = jnp.logical_and(x >= 0, x < self.n)
        return jnp.all(in_range)


class SpaceContainer(Space):
    """Container for structured (nested) spaces.

    Attributes:
        spaces: A pytree whose leaves are `Space` instances.
    """

    spaces: chex.ArrayTree

    def sample(self, key: chex.PRNGKey) -> chex.ArrayTree:
        """Sample every sub-space with its own PRNG key.

        Args:
            key: JAX PRNG key; `rng_split_like_tree` derives from it a tree
                of keys shaped like `spaces` (presumably one key per
                sub-space).

        Returns:
            A pytree with the structure of `spaces` holding one sample per
            sub-space.
        """
        keys = rng_split_like_tree(
            key, self.spaces, is_leaf=lambda s: isinstance(s, Space)
        )
        return jtu.tree_map(
            lambda s, key: s.sample(key),
            self.spaces,
            keys,
            is_leaf=lambda s: isinstance(s, Space),
        )

    @property
    def shape(self) -> chex.ArrayTree:
        """A pytree with the structure of `spaces` holding each sub-space's shape."""
        return jtu.tree_map(
            lambda s: s.shape,
            self.spaces,
            is_leaf=lambda s: isinstance(s, Space),
        )

    def contains(self, data: chex.ArrayTree) -> chex.Array:
        """Return whether every part of `data` lies in its matching sub-space.

        Args:
            data: A pytree with the same structure as `spaces`.

        Returns:
            A scalar boolean JAX array; `True` for an empty container.
        """
        per_space = jtu.tree_map(
            lambda s, x: s.contains(x),
            self.spaces,
            data,
            is_leaf=lambda s: isinstance(s, Space),
        )
        # `jtu.tree_all` applies Python's `all`, i.e. `bool()` on each leaf,
        # which raises on tracers under `jax.jit`/`jax.vmap` and on
        # non-scalar results. Reduce with JAX operations instead.
        flags = [jnp.all(flag) for flag in jtu.tree_leaves(per_space)]
        if not flags:
            # Same result as `all([])` for an empty tree.
            return jnp.array(True)
        return jnp.all(jnp.stack(flags))


def is_leaf_space(space):
    """Return whether `space` is a non-container `Space`.

    Objects that are not `Space` instances yield False, as do
    `SpaceContainer` instances; every other `Space` yields True.
    """
    if not isinstance(space, Space):
        return False
    return not isinstance(space, SpaceContainer)