Source code for evorl.types

  1"""Common type annotations and data structures."""
  2
  3import dataclasses
  4from functools import wraps
  5from collections.abc import Mapping, Sequence
  6from typing import Any, Union, TypeVar
  7from typing_extensions import dataclass_transform  # pytype: disable=not-supported-yet
  8
  9import chex
 10import jax
 11import jax.numpy as jnp
 12import jax.tree_util as jtu
 13
 14Metrics = Mapping[str, chex.ArrayTree]
 15Observation = Union[chex.Array, Mapping[str, chex.Array]]
 16Action = Union[chex.Array, Mapping[str, chex.Array]]
 17Reward = Union[chex.Array, Mapping[str, chex.Array]]
 18Done = Union[chex.Array, Mapping[str, chex.Array]]
 19PolicyExtraInfo = Mapping[str, Any]
 20ExtraInfo = Mapping[str, Any]
 21RewardDict = Mapping[str, Reward]
 22
 23LossDict = Mapping[str, chex.Array]
 24
 25EnvInternalState = chex.ArrayTree
 26
 27Params = chex.ArrayTree
 28ObsPreprocessorParams = Mapping[str, Any]
 29ActionPostprocessorParams = Mapping[str, Any]
 30
 31AgentID = Any
 32
 33ReplayBufferState = chex.ArrayTree
 34
 35Axis = int | None | Sequence[Any]
 36
 37MISSING_REWARD = -1e10
 38
 39
class PyTreeArrayMixin:
    """Batch-operate on a pytree whose leaves are jax.Array.

    It assumes all arrays have the same head shape.

    ``+`` and ``-`` combine ``self`` leaf-by-leaf with another tree ``o`` of
    the same structure. ``*`` and ``/`` instead apply one operand ``o`` (a
    scalar, or an array broadcastable against every leaf) to all leaves.
    """

    def __add__(self, o: chex.ArrayTree) -> chex.ArrayTree:
        return jtu.tree_map(lambda x, y: x + y, self, o)

    def __sub__(self, o: chex.ArrayTree) -> chex.ArrayTree:
        return jtu.tree_map(lambda x, y: x - y, self, o)

    def __mul__(self, o: chex.Numeric) -> chex.ArrayTree:
        # `o` is closed over and multiplied into every leaf; it is NOT matched
        # leaf-by-leaf like the operand of __add__ / __sub__.
        return jtu.tree_map(lambda x: x * o, self)

    def __neg__(self) -> chex.ArrayTree:
        return jtu.tree_map(lambda x: -x, self)

    def __truediv__(self, o: chex.Numeric) -> chex.ArrayTree:
        # Same single-operand semantics as __mul__.
        return jtu.tree_map(lambda x: x / o, self)

    def reshape(self, shape: Sequence[int]) -> chex.ArrayTree:
        """Reshape every leaf to ``shape``."""
        return jtu.tree_map(lambda x: x.reshape(shape), self)

    def slice(self, beg: int, end: int, strides=None) -> chex.ArrayTree:
        """Return ``x[beg:end:strides]`` for every leaf (slices the first axis)."""
        return jtu.tree_map(lambda x: x[beg:end:strides], self)

    def take(self, i, axis=0) -> chex.ArrayTree:
        """Take index/indices ``i`` along ``axis`` from every leaf.

        Out-of-range indices wrap around (``mode="wrap"``).
        """
        return jtu.tree_map(lambda x: jnp.take(x, i, axis=axis, mode="wrap"), self)

    def concatenate(self, *others: chex.ArrayTree, axis: int = 0) -> chex.ArrayTree:
        """Concatenate ``self`` with ``others`` leaf-by-leaf along ``axis``.

        All trees must share the same structure.
        """
        return jtu.tree_map(lambda *x: jnp.concatenate(x, axis=axis), self, *others)

    def index_set(
        self, idx: jax.Array | Sequence[jax.Array], o: chex.ArrayTree
    ) -> chex.ArrayTree:
        """Return a new tree with ``x.at[idx].set(y)`` for matching leaves of ``self`` and ``o``."""
        return jtu.tree_map(lambda x, y: x.at[idx].set(y), self, o)

    def index_sum(
        self, idx: jax.Array | Sequence[jax.Array], o: chex.ArrayTree
    ) -> chex.ArrayTree:
        """Return a new tree with ``x.at[idx].add(y)`` for matching leaves of ``self`` and ``o``."""
        return jtu.tree_map(lambda x, y: x.at[idx].add(y), self, o)

    @property
    def T(self) -> chex.ArrayTree:
        """Transpose every leaf (``x.T``)."""
        return jtu.tree_map(lambda x: x.T, self)


@jtu.register_pytree_node_class
class PyTreeDict(dict):
    """An easydict with pytree support.

    Every item is stored both as a dict entry and as an instance attribute, so
    ``d["k"]`` and ``d.k`` stay in sync (including on deletion). Values whose
    type is exactly ``dict``, ``list`` or ``tuple`` are converted recursively
    on assignment, with nested dicts becoming instances of this class. As a
    pytree, the children are the values ordered by sorted key, and the sorted
    keys are the aux data.
    """

    def __init__(self, *args, **kwargs):
        d = dict(*args, **kwargs)

        for k, v in d.items():
            setattr(self, k, v)

    @classmethod
    def _nested_convert(cls, obj):
        # Currently only dict, list and tuple are supported (exact types, not
        # their subclasses).
        if type(obj) is dict:
            return cls(obj)
        elif type(obj) is list:
            return list(cls._nested_convert(item) for item in obj)
        elif type(obj) is tuple:
            return tuple(cls._nested_convert(item) for item in obj)
        else:
            return obj

    def __setattr__(self, name, value):
        value = self._nested_convert(value)
        super().__setattr__(name, value)
        super().__setitem__(name, value)

    __setitem__ = __setattr__

    def __delattr__(self, name):
        """Delete ``name`` from both the attributes and the dict entries."""
        had_key = name in self
        try:
            super().__delattr__(name)
        except AttributeError:
            # Only an error when the name is missing from both stores
            # (e.g. a key inserted via dict.setdefault has no attribute).
            if not had_key:
                raise
        super().pop(name, None)

    def __delitem__(self, key):
        """Delete ``key`` from both the dict entries and the attributes."""
        super().__delitem__(key)  # raises KeyError if the key is missing
        self.__dict__.pop(key, None)

    def update(self, e=None, **f):
        """Like ``dict.update``, but every value goes through ``__setattr__``.

        ``e`` may be a mapping or an iterable of ``(key, value)`` pairs; items
        from ``f`` take precedence. ``e`` itself is never modified.
        """
        # Build a fresh dict: the previous `e.update(f)` mutated the caller's
        # mapping and failed for iterables of pairs.
        for k, v in dict(e or (), **f).items():
            setattr(self, k, v)

    def pop(self, k, d=None):
        """Remove ``k`` and return its value, or return ``d`` if ``k`` is absent."""
        # Drop the mirrored attribute only if present; the previous
        # unconditional delattr raised AttributeError for missing keys, so
        # the default `d` was never returned.
        self.__dict__.pop(k, None)
        return super().pop(k, d)

    def copy(self):
        """Return a shallow copy of the same class."""
        d = super().copy()  # plain dict
        return self.__class__(d)

    def replace(self, **d):
        """Return a shallow copy with the given items set."""
        clone = self.copy()
        clone.update(**d)
        return clone

    def tree_flatten(self):
        # Use dict.keys directly: a user item named "keys" is stored as an
        # instance attribute and would shadow the bound method.
        keys = sorted(super().keys())
        return tuple(self[k] for k in keys), tuple(keys)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(dict(zip(aux_data, children)))


@jtu.register_pytree_node_class
class State(PyTreeDict):
    """A general State class.

    An alias of PyTreeDict, registered as its own pytree node type. This class
    is specifically used for `Workflow` state.
    """


def pytree_field(*, static=False, **kwargs):
    """Define a pytree field in our dataclass.

    Args:
        static: Setting to True marks the field as static for pytree: it is
            stored as pytree metadata (aux data) instead of as a leaf, so
            changing data in such a field causes a re-jit of func.
        **kwargs: Forwarded to ``dataclasses.field``. A user-supplied
            ``metadata`` mapping is copied, not mutated; its ``"static"`` key
            is overwritten by ``static``.

    Returns:
        A dataclass field.
    """
    # Copy instead of updating in place: the caller's mapping may be shared by
    # several fields, be read-only (e.g. MappingProxyType), or be None.
    metadata = dict(kwargs.get("metadata") or {})
    metadata["static"] = static
    kwargs["metadata"] = metadata

    return dataclasses.field(**kwargs)
168 169 170_T = TypeVar("T") 171 172
@dataclass_transform(field_specifiers=(pytree_field,))  # type: ignore[literal-required]
def dataclass(clz: _T, *, pure_data=False, **kwargs) -> _T:
    """Convert ``clz`` into a dataclass that is also registered as a JAX pytree.

    Args:
        clz: The class to convert. Its ``jax.Array`` class-level defaults are
            rewritten in place before ``dataclasses.dataclass`` is applied.
        pure_data: When True and ``jax.tree_util.register_dataclass`` exists,
            register through that builtin. Otherwise register custom
            flatten/unflatten functions with key paths.
        **kwargs: Forwarded to ``dataclasses.dataclass``. ``frozen`` defaults
            to True when not given.

    Returns:
        The dataclass with a ``replace(**updates)`` method attached. Fields
        whose metadata has ``static=True`` become pytree aux data; all other
        fields become pytree children.
    """
    # Frozen unless the caller explicitly decided otherwise.
    kwargs.setdefault("frozen", True)

    # JAX arrays are immutable but unhashable (__hash__ is None), so
    # `dataclasses` would reject them as mutable defaults. Turn each such
    # default into a default_factory; `value=default` binds the current array
    # instead of the loop variable.
    for attr_name in clz.__annotations__:
        default = getattr(clz, attr_name, None)
        if isinstance(default, jax.Array):
            setattr(
                clz,
                attr_name,
                pytree_field(default_factory=lambda value=default: value),
            )

    data_clz = dataclasses.dataclass(**kwargs)(clz)  # type: ignore

    # Static fields go to pytree aux data; everything else is a child.
    meta_fields, data_fields = [], []
    for field_info in dataclasses.fields(data_clz):
        bucket = meta_fields if field_info.metadata.get("static", False) else data_fields
        bucket.append(field_info.name)

    def replace(self, **updates):
        """Returns a new object replacing the specified fields with new values."""
        return dataclasses.replace(self, **updates)

    data_clz.replace = replace

    # TODO: can we always use jax.tree_util.register_dataclass?
    if pure_data and hasattr(jax.tree_util, "register_dataclass"):
        # Use the optimized C++ dataclass builtin (jax>=0.4.26)
        jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
        return data_clz  # type: ignore

    def flatten(obj):
        aux = tuple(getattr(obj, name) for name in meta_fields)
        children = tuple(getattr(obj, name) for name in data_fields)
        return children, aux

    def flatten_with_keys(obj):
        aux = tuple(getattr(obj, name) for name in meta_fields)
        keyed_children = tuple(
            (jax.tree_util.GetAttrKey(name), getattr(obj, name))
            for name in data_fields
        )
        return keyed_children, aux

    def unflatten(aux, children):
        # Rebuild through __init__ using keyword arguments.
        init_kwargs = dict(zip(meta_fields, aux))
        init_kwargs.update(zip(data_fields, children))
        return data_clz(**init_kwargs)

    jax.tree_util.register_pytree_with_keys(
        data_clz,
        flatten_with_keys,
        unflatten,
        flatten,
    )
    return data_clz  # type: ignore


@dataclass_transform(field_specifiers=(pytree_field,), kw_only_default=True)
class PyTreeNode:
    """A pytree dataclass for Node.

    Every subclass is converted by ``dataclass`` with keyword-only fields by
    default (``kw_only=True``); other class keywords are forwarded. While a
    subclass's ``__post_init__`` runs, plain ``self.xxx = value`` is allowed.
    Afterwards, use ``set_frozen_attr`` to force-set an attribute.
    """

    def __init_subclass__(cls, kw_only=True, **kwargs):
        # May be inherited (already wrapped) from a PyTreeNode parent.
        original_post_init = getattr(cls, "__post_init__", None)

        if original_post_init:

            @wraps(original_post_init)
            def wrapped_post_init(self, *args, **post_init_kwargs) -> None:
                # Restore the previous flag value instead of forcing False.
                # Otherwise, when a child's __post_init__ calls
                # super().__post_init__(), the parent's wrapper would close the
                # window and later `self.xxx = value` lines in the child would
                # hit the frozen __setattr__.
                previous = getattr(self, "_is_in_post_init", False)
                object.__setattr__(self, "_is_in_post_init", True)
                try:
                    original_post_init(self, *args, **post_init_kwargs)
                finally:
                    object.__setattr__(self, "_is_in_post_init", previous)

            cls.__post_init__ = wrapped_post_init

        dataclass(cls, kw_only=kw_only, **kwargs)

        # Allow self.xxx = value in __post_init__
        if original_post_init:
            original_setattr = getattr(cls, "__setattr__")

            def custom_setattr(self, name: str, value: Any) -> None:
                if getattr(self, "_is_in_post_init", False):
                    # Inside __post_init__: bypass the dataclass __setattr__.
                    object.__setattr__(self, name, value)
                else:
                    original_setattr(self, name, value)

            cls.__setattr__ = custom_setattr

    def set_frozen_attr(self, name, value):
        """Force set attribute after __init__ of the dataclass."""
        object.__setattr__(self, name, value)


@dataclass_transform(field_specifiers=(pytree_field,), kw_only_default=True)
class PyTreeData:
    """A pytree dataclass for Data.

    Like `PyTreeNode`, but all fields must be set at __init__. It provides no
    __post_init__ attribute-assignment support and no set_frozen_attr() method.
    Subclasses are converted with `dataclass(..., pure_data=True)`.
    """

    def __init_subclass__(cls, kw_only=True, **kwargs):
        dataclass_options = {"kw_only": kw_only, **kwargs}
        dataclass(cls, pure_data=True, **dataclass_options)