1"""Common type annotations and data structures."""
2
3import dataclasses
4from functools import wraps
5from collections.abc import Mapping, Sequence
6from typing import Any, Union, TypeVar
7from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
8
9import chex
10import jax
11import jax.numpy as jnp
12import jax.tree_util as jtu
13
# Per-step interaction values: either one array or a mapping of named arrays.
_ArrayOrArrayDict = Union[chex.Array, Mapping[str, chex.Array]]

# Environment / agent interaction types.
Observation = _ArrayOrArrayDict
Action = _ArrayOrArrayDict
Reward = _ArrayOrArrayDict
Done = _ArrayOrArrayDict
RewardDict = Mapping[str, Reward]
PolicyExtraInfo = Mapping[str, Any]
ExtraInfo = Mapping[str, Any]
EnvInternalState = chex.ArrayTree
AgentID = Any

# Training bookkeeping types.
Metrics = Mapping[str, chex.ArrayTree]
LossDict = Mapping[str, chex.Array]
ReplayBufferState = chex.ArrayTree

# Parameter containers.
Params = chex.ArrayTree
ObsPreprocessorParams = Mapping[str, Any]
ActionPostprocessorParams = Mapping[str, Any]

# Axis specification: a single int, None, or a sequence.
Axis = int | None | Sequence[Any]

# Large negative placeholder reward; presumably marks a missing reward value
# (NOTE(review): confirm intended use at call sites).
MISSING_REWARD = -1e10
38
39
class PyTreeArrayMixin:
    """Leaf-wise arithmetic and array helpers for pytrees of jax.Array.

    Meant to be mixed into a class that is itself registered as a JAX pytree:
    every method maps over the leaves of ``self`` with ``jtu.tree_map`` and
    returns the rebuilt pytree. All leaves are assumed to share the same
    leading (batch) shape.

    Operator asymmetry: ``+`` and ``-`` pair ``self`` leaf by leaf with another
    pytree of identical structure, while ``*`` and ``/`` apply the single
    operand ``o`` unchanged to every leaf (e.g. a scalar).
    """

    def __add__(self, o: chex.ArrayTree) -> chex.ArrayTree:
        # `o` must have the same tree structure as `self`.
        return jtu.tree_map(lambda lhs, rhs: lhs + rhs, self, o)

    def __sub__(self, o: chex.ArrayTree) -> chex.ArrayTree:
        # `o` must have the same tree structure as `self`.
        return jtu.tree_map(lambda lhs, rhs: lhs - rhs, self, o)

    def __mul__(self, o: chex.ArrayTree) -> chex.ArrayTree:
        # `o` is applied as-is to each leaf, not zipped with the tree.
        return jtu.tree_map(lambda leaf: leaf * o, self)

    def __neg__(self) -> chex.ArrayTree:
        return jtu.tree_map(lambda leaf: -leaf, self)

    def __truediv__(self, o: chex.ArrayTree) -> chex.ArrayTree:
        # Same single-operand rule as `__mul__`.
        return jtu.tree_map(lambda leaf: leaf / o, self)

    def reshape(self, shape: Sequence[int]) -> chex.ArrayTree:
        """Reshape every leaf to `shape`."""
        return jtu.tree_map(lambda leaf: leaf.reshape(shape), self)

    def slice(self, beg: int, end: int, strides=None) -> chex.ArrayTree:
        """Apply ``leaf[beg:end:strides]`` (first axis) to every leaf."""
        return jtu.tree_map(lambda leaf: leaf[beg:end:strides], self)

    def take(self, i, axis=0) -> chex.ArrayTree:
        """Gather index/indices `i` along `axis`; out-of-range indices wrap."""
        return jtu.tree_map(
            lambda leaf: jnp.take(leaf, i, axis=axis, mode="wrap"), self
        )

    def concatenate(self, *others: chex.ArrayTree, axis: int = 0) -> chex.ArrayTree:
        """Concatenate `self` with each pytree in `others` leaf-wise along `axis`."""
        return jtu.tree_map(
            lambda *leaves: jnp.concatenate(leaves, axis=axis), self, *others
        )

    def index_set(
        self, idx: jax.Array | Sequence[jax.Array], o: chex.ArrayTree
    ) -> chex.ArrayTree:
        """Return ``leaf.at[idx].set(value)`` per leaf, values taken from `o`."""
        return jtu.tree_map(lambda leaf, val: leaf.at[idx].set(val), self, o)

    def index_sum(
        self, idx: jax.Array | Sequence[jax.Array], o: chex.ArrayTree
    ) -> chex.ArrayTree:
        """Return ``leaf.at[idx].add(value)`` per leaf, values taken from `o`."""
        return jtu.tree_map(lambda leaf, val: leaf.at[idx].add(val), self, o)

    @property
    def T(self):
        """The ``.T`` of every leaf."""
        return jtu.tree_map(lambda leaf: leaf.T, self)
86
87
@jtu.register_pytree_node_class
class PyTreeDict(dict):
    """An easydict-style dict that is also a JAX pytree.

    Every key is mirrored as an instance attribute, so ``d["a"]`` and ``d.a``
    hold the same value; keys must therefore be strings. Plain ``dict``,
    ``list`` and ``tuple`` values are converted recursively on assignment
    (nested dicts become instances of this class).

    As a pytree, the children are the values ordered by sorted key and the
    auxiliary data is the tuple of sorted keys.
    """

    def __init__(self, *args, **kwargs):
        # Accept the same arguments as `dict(...)`, but route each entry
        # through `__setattr__` so it is converted and mirrored as attribute.
        d = dict(*args, **kwargs)

        for k, v in d.items():
            setattr(self, k, v)

    @classmethod
    def _nested_convert(cls, obj):
        """Recursively convert plain containers inside `obj`.

        Only exact ``dict``, ``list`` and ``tuple`` instances are handled;
        their subclasses (e.g. namedtuples) are returned unchanged.
        """
        if type(obj) is dict:
            return cls(obj)
        elif type(obj) is list:
            return [cls._nested_convert(item) for item in obj]
        elif type(obj) is tuple:
            return tuple(cls._nested_convert(item) for item in obj)
        else:
            return obj

    def __setattr__(self, name, value):
        # Keep the attribute view and the item view in sync.
        value = self._nested_convert(value)
        super().__setattr__(name, value)
        super().__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        """Update from `e` (mapping or iterable of pairs) and keywords `f`.

        Keyword entries win over entries from `e`. `e` itself is never
        modified.
        """
        # Copy `e` first: updating it in place would mutate the caller's
        # object whenever keyword arguments are also given.
        d = dict(e) if e else {}
        d.update(f)
        for k, v in d.items():
            setattr(self, k, v)

    def pop(self, k, d=None):
        """Remove `k` and return its value, or return `d` if `k` is absent."""
        # Drop the mirrored attribute only if present; an unconditional
        # delattr raised AttributeError for missing keys despite the default.
        vars(self).pop(k, None)
        return super().pop(k, d)

    def copy(self):
        """Return a shallow copy with the same class."""
        d = super().copy()  # plain dict
        return self.__class__(d)

    def replace(self, **d):
        """Return a copy with the given keys set to new values."""
        clone = self.copy()
        clone.update(**d)
        return clone

    def tree_flatten(self):
        # Sorted keys give a child order independent of insertion order.
        keys = sorted(self.keys())
        return tuple(self[k] for k in keys), tuple(keys)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(dict(zip(aux_data, children)))
143
144
@jtu.register_pytree_node_class
class State(PyTreeDict):
    """General-purpose state container.

    Behaves exactly like `PyTreeDict`; this separate class is specifically
    used for `Workflow` state. It is registered with JAX on its own because
    pytree registration applies per class, not to subclasses.
    """
153
154
def pytree_field(*, static=False, **kwargs):
    """Define a pytree field in our dataclass.

    Args:
        static: If True, the field is marked static: it is kept as pytree
            auxiliary data instead of a child, so changing its value causes a
            re-jit of functions that take the object. Defaults to False.
        **kwargs: Forwarded to ``dataclasses.field``. A user-supplied
            ``metadata`` mapping is copied (never mutated) and extended with
            the ``"static"`` key; ``metadata=None`` is treated as empty.

    Returns:
        A dataclass field.
    """
    # Build a fresh dict. Updating the caller's mapping in place would leak
    # the "static" flag into shared dicts, and fails for None or read-only
    # mappings.
    metadata = dict(kwargs.get("metadata") or {})
    metadata["static"] = static
    kwargs["metadata"] = metadata

    return dataclasses.field(**kwargs)
168
169
# Type variable letting `dataclass(clz)` be typed as returning the type of
# `clz`. Its name string must match the variable name for type checkers.
_T = TypeVar("_T")
171
172
@dataclass_transform(field_specifiers=(pytree_field,))  # type: ignore[literal-required]
def dataclass(clz: _T, *, pure_data=False, **kwargs) -> _T:
    """Convert `clz` into a dataclass and register it as a JAX pytree.

    Args:
        clz: Class to convert. It is modified in place and returned.
        pure_data: When True and ``jax.tree_util.register_dataclass`` exists,
            register through that builtin. Otherwise register hand-written
            flatten/unflatten functions via ``register_pytree_with_keys``.
        **kwargs: Forwarded to ``dataclasses.dataclass``. ``frozen`` defaults
            to True unless given explicitly.

    Returns:
        The dataclass, extended with a ``replace(**updates)`` method that
        wraps ``dataclasses.replace``.

    Fields whose metadata has a truthy ``"static"`` entry (as set by
    ``pytree_field(static=True)``) become pytree auxiliary data; all other
    fields become pytree children. In the fallback registration path,
    unflattening calls the class constructor with every field as a keyword.
    """
    kwargs.setdefault("frozen", True)

    # A jax.Array is immutable but unhashable (__hash__ is None), which does
    # not satisfy dataclass default-value requirements. Wrap each such class
    # default in a default_factory. The `arr=default_value` lambda default
    # captures the current array instead of the loop variable.
    for attr_name in clz.__annotations__:
        if not hasattr(clz, attr_name):
            continue
        default_value = getattr(clz, attr_name)
        if isinstance(default_value, jax.Array):
            setattr(
                clz,
                attr_name,
                pytree_field(default_factory=lambda arr=default_value: arr),
            )

    data_clz = dataclasses.dataclass(**kwargs)(clz)  # type: ignore

    # Split field names, in declaration order, into static and dynamic parts.
    all_fields = dataclasses.fields(data_clz)
    meta_fields = [f.name for f in all_fields if f.metadata.get("static", False)]
    data_fields = [
        f.name for f in all_fields if not f.metadata.get("static", False)
    ]

    def replace(self, **updates):
        """Returns a new object replacing the specified fields with new values."""
        return dataclasses.replace(self, **updates)

    data_clz.replace = replace

    # TODO: can we always use jax.tree_util.register_dataclass?
    use_native = pure_data and hasattr(jax.tree_util, "register_dataclass")
    if use_native:
        # Optimized C++ dataclass builtin (jax>=0.4.26).
        jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
        return data_clz  # type: ignore

    def flatten(obj):
        aux = tuple(getattr(obj, n) for n in meta_fields)
        children = tuple(getattr(obj, n) for n in data_fields)
        return children, aux

    def flatten_with_keys(obj):
        aux = tuple(getattr(obj, n) for n in meta_fields)
        keyed_children = tuple(
            (jax.tree_util.GetAttrKey(n), getattr(obj, n)) for n in data_fields
        )
        return keyed_children, aux

    def unflatten(aux, children):
        init_kwargs = dict(zip(meta_fields, aux))
        init_kwargs.update(zip(data_fields, children))
        return data_clz(**init_kwargs)

    jax.tree_util.register_pytree_with_keys(
        data_clz,
        flatten_with_keys,
        unflatten,
        flatten,
    )
    return data_clz  # type: ignore
239
240
@dataclass_transform(field_specifiers=(pytree_field,), kw_only_default=True)
class PyTreeNode:
    """A pytree dataclass for Node.

    Subclassing turns the subclass into a pytree dataclass via `dataclass`:
    frozen unless ``frozen=False`` is passed, keyword-only unless
    ``kw_only=False`` is passed. Other class keywords are forwarded as well.

    Unlike a plain frozen dataclass, ``self.xxx = value`` is allowed while
    ``__post_init__`` runs, and `set_frozen_attr` can force-set attributes
    afterwards.
    """

    def __init_subclass__(cls, kw_only=True, **kwargs):
        original_post_init = getattr(cls, "__post_init__", None)

        if original_post_init:

            @wraps(original_post_init)
            def wrapped_post_init(self, *args, **post_init_kwargs) -> None:
                # Restore the previous flag instead of forcing False. When a
                # subclass __post_init__ calls super().__post_init__() (also
                # wrapped), forcing False would re-freeze the instance before
                # the subclass body finishes assigning attributes.
                was_in_post_init = getattr(self, "_is_in_post_init", False)
                object.__setattr__(self, "_is_in_post_init", True)
                try:
                    original_post_init(self, *args, **post_init_kwargs)
                finally:
                    object.__setattr__(self, "_is_in_post_init", was_in_post_init)

            cls.__post_init__ = wrapped_post_init

        dataclass(cls, kw_only=kw_only, **kwargs)

        # Allow self.xxx = value in __post_init__
        if original_post_init:
            # The __setattr__ installed by `dataclass` (frozen by default).
            original_setattr = getattr(cls, "__setattr__")

            def custom_setattr(self, name: str, value: Any) -> None:
                if getattr(self, "_is_in_post_init", False):
                    # Inside __post_init__: bypass the frozen guard.
                    object.__setattr__(self, name, value)
                else:
                    original_setattr(self, name, value)

            cls.__setattr__ = custom_setattr

    def set_frozen_attr(self, name, value):
        """Force set attribute after __init__ of the dataclass."""
        object.__setattr__(self, name, value)
278
279
@dataclass_transform(field_specifiers=(pytree_field,), kw_only_default=True)
class PyTreeData:
    """A pytree dataclass for Data.

    Like `PyTreeNode`, subclasses become pytree dataclasses (keyword-only by
    default), but all fields must be set at ``__init__``. There is no
    ``set_frozen_attr()`` method and no ``__post_init__`` assignment hook.
    Registration uses the ``pure_data=True`` path of `dataclass`.
    """

    def __init_subclass__(cls, kw_only=True, **kwargs):
        # Forward all class keywords, plus kw_only, to the `dataclass` helper.
        options = dict(kwargs, kw_only=kw_only)
        dataclass(cls, pure_data=True, **options)