1import dataclasses
2from collections.abc import Callable
3
4import chex
5import jax
6import jax.numpy as jnp
7import numpy as np
8
9from .distributed import pmean
10from .types import LossDict, PyTreeData, PyTreeDict
11
12
[docs]
def metric_field(
    *,
    reduce_fn: Callable[[chex.Array, str | None], chex.Array] | None = None,
    static=False,
    **kwargs,
):
    """Define a metric field in `MetricBase`.

    Args:
        reduce_fn: A function ``reduce_fn(value, axis_name)`` used by
            `MetricBase.all_reduce` to reduce the metric value across devices,
            e.g. a cross-device mean such as `jax.lax.pmean`. If None, the
            field is not reduced.
        static: Whether the field is static with respect to the pytree
            (stored as ``metadata["static"]``; presumably consumed by
            `PyTreeData` — confirm there).
        **kwargs: Extra keyword arguments forwarded to `dataclasses.field`,
            e.g. ``default`` or ``default_factory``. A user-supplied
            ``metadata`` mapping is copied and never modified in place.

    Returns:
        A dataclass field.
    """
    # Build a fresh dict instead of updating the caller's mapping in place:
    # `dataclasses.field` keeps a live read-only view of the dict it receives,
    # so mutating a shared `metadata` dict would silently change fields that
    # were created earlier from the same dict. `or {}` also accepts
    # `metadata=None`, and `dict(...)` accepts read-only mappings.
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata.update(static=static, reduce_fn=reduce_fn)

    return dataclasses.field(metadata=metadata, **kwargs)
32
33
[docs]
class MetricBase(PyTreeData, kw_only=True):
    """Base class for all metrics.

    Subclasses declare their values as dataclass fields. Fields created with
    `metric_field(reduce_fn=...)` carry a ``reduce_fn`` in their metadata and
    are reduced across devices by `all_reduce`; all other fields are left
    unchanged.
    """

    def all_reduce(self, dp_axis_name: str | None = None):
        """Reduce the metric fields across data-parallel devices.

        Every field whose metadata holds a callable ``reduce_fn`` is replaced
        by ``reduce_fn(value, dp_axis_name)``. Fields without one keep their
        value.

        Args:
            dp_axis_name: Name of the data-parallel axis to reduce over. If
                None, no reduction is performed and ``self`` is returned.

        Returns:
            A metric of the same type holding the reduced field values.
        """
        if dp_axis_name is None:
            return self

        reduced = {}
        for field in dataclasses.fields(self):
            reduce_fn = field.metadata.get("reduce_fn", None)
            if callable(reduce_fn):
                reduced[field.name] = reduce_fn(getattr(self, field.name), dp_axis_name)

        if not reduced:
            return self

        # Pass only the fields that were actually reduced. Untouched fields are
        # carried over by `replace()`, and this avoids handing it fields
        # declared with `init=False`, which `dataclasses.replace` rejects
        # (assuming `PyTreeData.replace` follows those semantics — confirm).
        return self.replace(**reduced)

    def to_local_dict(self):
        """Convert the metric to native python structures recursively.

        The data in the metric object is converted to local data types: list,
        tuple, dict, NamedTuple, etc. Jax arrays are converted to numpy arrays.

        Returns:
            A converted dict.
        """
        # Resolves to the module-level `to_local_dict` function, not this method.
        return to_local_dict(self)
60
61
[docs]
class WorkflowMetric(MetricBase):
    """Workflow metrics for RLWorkflow.

    Attributes:
        sampled_timesteps: The total number of sampled timesteps from environments.
        sampled_episodes: The total number of sampled episodes from environments.
        iterations: The total number of workflow iterations.
    """

    # Scalar uint32 counters declared as plain fields (no `reduce_fn`
    # metadata), so `all_reduce()` leaves them unchanged.
    # NOTE(review): uint32 wraps around after ~4.29e9; confirm that
    # `sampled_timesteps` cannot exceed that in long runs.
    sampled_timesteps: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    sampled_episodes: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
73
74
[docs]
class TrainMetric(MetricBase):
    """Training metrics for RLWorkflow.

    Attributes:
        train_episode_return: The return of the training episode.
        loss: The loss value of the training step.
        raw_loss_dict: The raw loss dict of the training step.
    """

    # No reduce_fn: this value is reduced manually in step().
    train_episode_return: chex.Array | None = None

    # No reduce_fn since it's already reduced in step().
    loss: chex.Array = jnp.zeros(())
    # Unlike `loss`, this dict is reduced by `all_reduce()` via `pmean`
    # (presumably a cross-device mean; see `.distributed`).
    raw_loss_dict: LossDict = metric_field(default_factory=PyTreeDict, reduce_fn=pmean)
90
91
[docs]
class EvaluateMetric(MetricBase):
    """Evaluation metrics for RLWorkflow.

    Attributes:
        episode_returns: The return array of evaluation episodes.
        episode_lengths: The length array of evaluation episodes.
    """

    # Both fields have no default (must be passed explicitly) and are reduced
    # by `all_reduce()` via `pmean` (presumably a cross-device mean).
    episode_returns: chex.Array = metric_field(reduce_fn=pmean)
    episode_lengths: chex.Array = metric_field(reduce_fn=pmean)
102
103
[docs]
class ECWorkflowMetric(MetricBase):
    """Workflow metrics for ECWorkflow.

    Attributes:
        best_objective: The best objective value found so far.
        sampled_episodes: The total number of sampled episodes from environments.
        sampled_timesteps_m: The total number of sampled timesteps from environments, measured in millions.
        iterations: The total number of workflow iterations.
    """

    # No default: must be passed explicitly when constructing the metric.
    best_objective: chex.Array
    sampled_episodes: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Stored as float32 (millions of timesteps), so it keeps only ~7
    # significant decimal digits.
    sampled_timesteps_m: chex.Array = jnp.zeros((), dtype=jnp.float32)
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
118
119
[docs]
class MultiObjectiveECWorkflowMetric(MetricBase):
    """Workflow metrics for MultiObjectiveECWorkflow.

    Attributes:
        sampled_episodes: The total number of sampled episodes from environments.
        sampled_timesteps_m: The total number of sampled timesteps from environments, measured in millions.
        iterations: The total number of workflow iterations.
    """

    # Same counters as `ECWorkflowMetric`, minus `best_objective`.
    sampled_episodes: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Stored as float32 (millions of timesteps).
    sampled_timesteps_m: chex.Array = jnp.zeros((), dtype=jnp.float32)
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
132
133
[docs]
class ECTrainMetric(MetricBase):
    """Training metrics for ECWorkflow.

    Attributes:
        objectives: The objective values for current step.
        ec_metrics: The extra metrics of the training step.
    """

    # Both fields have no default (must be passed explicitly) and carry no
    # `reduce_fn`, so `all_reduce()` leaves them unchanged.
    objectives: chex.Array
    ec_metrics: chex.ArrayTree
144
145
[docs]
def to_local_dict(obj, *, dict_factory=dict):
    """Convert a dataclass instance to native python structures recursively.

    Works like `dataclasses.asdict`, but additionally converts `jax.Array`
    leaves to `numpy.ndarray` and `PyTreeDict` values through
    ``dict_factory`` with sorted keys.

    Args:
        obj: A dataclass instance, e.g. an instance of a `MetricBase` subclass.
        dict_factory: Callable that builds the mapping for each converted
            dataclass (and `PyTreeDict`) from an iterable of ``(key, value)``
            pairs.

    Returns:
        The converted object, built with ``dict_factory``.

    Raises:
        TypeError: If ``obj`` is not a dataclass instance. Dataclass classes
            themselves are rejected as well, matching `dataclasses.asdict`.
    """
    # `is_dataclass` is also True for dataclass *classes*; only instances
    # have field values that can be converted.
    if not dataclasses.is_dataclass(obj) or isinstance(obj, type):
        raise TypeError("to_local_dict() should be called on dataclass instances")
    return _to_local_dict_inner(obj, dict_factory)
150
151
def _to_local_dict_inner(obj, dict_factory):
    """Recursive worker for `to_local_dict`.

    Follows the recursion of `dataclasses.asdict`:

    - Dataclass instances become ``dict_factory`` results.
    - Namedtuples, lists, tuples and dicts are rebuilt with their own type.
    - Every other value is returned as a leaf.

    On top of that, `PyTreeDict` instances are converted through
    ``dict_factory`` with their keys in sorted order, and `jax.Array` leaves
    are converted to `numpy.ndarray` via `np.array`.
    """
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # Only instances are converted; a dataclass *class* stored as a value
        # falls through and is returned unchanged as a leaf.
        return dict_factory(
            [
                (f.name, _to_local_dict_inner(getattr(obj, f.name), dict_factory))
                for f in dataclasses.fields(obj)
            ]
        )

    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # obj is a namedtuple. Its constructor takes the fields positionally
        # rather than a single iterable, so unpack the converted values
        # (see bpo-34363). `_asdict()` is deliberately not used: it does not
        # recurse into the fields, and keeping the namedtuple type means the
        # result can still be used where a namedtuple was a dict key.
        return type(obj)(*[_to_local_dict_inner(v, dict_factory) for v in obj])

    if isinstance(obj, (list, tuple)):
        # Assumes list/tuple (sub)types accept a single iterable argument;
        # namedtuples, which do not, are handled above.
        return type(obj)(_to_local_dict_inner(v, dict_factory) for v in obj)

    if isinstance(obj, PyTreeDict):
        # Converted through dict_factory (not type(obj)), keys in sorted order.
        return dict_factory(
            (
                _to_local_dict_inner(k, dict_factory),
                _to_local_dict_inner(obj[k], dict_factory),
            )
            for k in sorted(obj.keys())
        )

    if isinstance(obj, dict):
        if hasattr(type(obj), "default_factory"):
            # defaultdict-like: the constructor's first positional argument is
            # the factory, so `type(obj)(pairs)` would raise TypeError.
            # Rebuild it with the same factory, then fill it item by item.
            result = type(obj)(obj.default_factory)
            for k, v in obj.items():
                result[_to_local_dict_inner(k, dict_factory)] = _to_local_dict_inner(v, dict_factory)
            return result
        return type(obj)(
            (
                _to_local_dict_inner(k, dict_factory),
                _to_local_dict_inner(v, dict_factory),
            )
            for k, v in obj.items()
        )

    if isinstance(obj, jax.Array):
        # np.array copies the data into a new numpy array.
        return np.array(obj)

    return obj