Source code for evorl.metrics

  1import dataclasses
  2from collections.abc import Callable
  3
  4import chex
  5import jax
  6import jax.numpy as jnp
  7import numpy as np
  8
  9from .distributed import pmean
 10from .types import LossDict, PyTreeData, PyTreeDict
 11
 12
def metric_field(
    *,
    reduce_fn: "Callable[[chex.Array, str | None], chex.Array] | None" = None,
    static=False,
    **kwargs,
):
    """Define a metric field in `MetricBase`.

    Args:
        reduce_fn: A function called as ``reduce_fn(value, axis_name)`` by
            `MetricBase.all_reduce` to reduce the field across devices, e.g. a
            cross-device mean such as ``jax.lax.pmean``. ``None`` means the
            field is never reduced.
        static: Whether the field is static related to pytree. Stored in the
            field metadata under the ``"static"`` key.
        **kwargs: Extra keyword arguments forwarded to `dataclasses.field`
            (``default``, ``default_factory``, ``metadata``, ...).

    Returns:
        A dataclass field whose metadata is the caller-provided ``metadata``
        (if any) with the ``"static"`` and ``"reduce_fn"`` entries set on top.
    """
    # Build a fresh dict instead of updating the caller's ``metadata`` in place:
    # the caller's mapping may be shared between several fields, read-only
    # (e.g. ``types.MappingProxyType``) or ``None``.
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata.update(static=static, reduce_fn=reduce_fn)
    return dataclasses.field(metadata=metadata, **kwargs)
32 33
class MetricBase(PyTreeData, kw_only=True):
    """Base class for all metrics.

    Subclasses declare their fields either with plain defaults or with
    `metric_field`; a field whose metadata carries a callable ``reduce_fn`` is
    reduced across devices by :meth:`all_reduce`.
    """

    def all_reduce(self, dp_axis_name: str | None = None):
        """Reduce the metric fields over a data-parallel axis.

        Args:
            dp_axis_name: Name of the data-parallel axis. When ``None``, no
                reduction is applied and every field keeps its value.

        Returns:
            ``self`` if the metric has no fields; otherwise a copy produced by
            ``self.replace`` in which each field with a callable ``reduce_fn``
            metadata entry holds ``reduce_fn(value, dp_axis_name)``.
        """

        def _reduced_value(field):
            current = getattr(self, field.name)
            fn = field.metadata.get("reduce_fn", None)
            # Fields without a callable reduce_fn (plain defaults, or
            # metric_field(reduce_fn=None)) are passed through untouched.
            if dp_axis_name is None or not callable(fn):
                return current
            return fn(current, dp_axis_name)

        reduced = {f.name: _reduced_value(f) for f in dataclasses.fields(self)}
        return self.replace(**reduced) if reduced else self

    def to_local_dict(self):
        """Convert the metric to native Python structures recursively.

        Nested dataclasses, lists, tuples, NamedTuples and dicts are rebuilt
        recursively, and JAX arrays are converted to NumPy arrays.

        Returns:
            A dict mapping field names to their converted values.
        """
        # Resolves to the module-level function of the same name.
        return to_local_dict(self)
60 61
class WorkflowMetric(MetricBase):
    """Workflow metrics for RLWorkflow.

    Attributes:
        sampled_timesteps: The total number of sampled timesteps from environments.
        sampled_episodes: The total number of sampled episodes from environments.
        iterations: The total number of workflow iterations.
    """

    # Zero-initialized scalar uint32 counters. They carry no `reduce_fn`
    # metadata, so `all_reduce` leaves them unchanged.
    # NOTE(review): uint32 arithmetic wraps around after 2**32 - 1 (~4.3e9);
    # confirm this is enough for very long runs.
    sampled_timesteps: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    sampled_episodes: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
73 74
class TrainMetric(MetricBase):
    """Training metrics for RLWorkflow.

    Attributes:
        train_episode_return: The return of the training episode. Defaults to None.
        loss: The loss value of the training step.
        raw_loss_dict: The raw loss dict of the training step.
    """

    # Reduced manually in step(), so no reduce_fn is attached here.
    train_episode_return: chex.Array | None = None

    # `loss` needs no reduce_fn since it's already reduced in step().
    loss: chex.Array = jnp.zeros(())
    # `raw_loss_dict` is reduced across devices with `pmean` when `all_reduce`
    # is given a data-parallel axis name.
    raw_loss_dict: LossDict = metric_field(default_factory=PyTreeDict, reduce_fn=pmean)
90 91
class EvaluateMetric(MetricBase):
    """Evaluation metrics for RLWorkflow.

    Attributes:
        episode_returns: The return array of evaluation episodes.
        episode_lengths: The length array of evaluation episodes.
    """

    # Both fields are required (metric_field is called without a default) and
    # are reduced across devices with `pmean` when `all_reduce` is given a
    # data-parallel axis name.
    episode_returns: chex.Array = metric_field(reduce_fn=pmean)
    episode_lengths: chex.Array = metric_field(reduce_fn=pmean)
102 103
class ECWorkflowMetric(MetricBase):
    """Workflow metrics for ECWorkflow.

    Attributes:
        best_objective: The best objective value found so far.
        sampled_episodes: The total number of sampled episodes from environments.
        sampled_timesteps_m: The total number of sampled timesteps from environments, measured in millions.
        iterations: The total number of workflow iterations.
    """

    # Required: no default is provided.
    best_objective: chex.Array
    sampled_episodes: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    # Stored as float32, presumably so fractional millions can be represented.
    # NOTE(review): float32 has a 24-bit significand, so adding small increments
    # to a large running total may lose precision; confirm this is acceptable.
    sampled_timesteps_m: chex.Array = jnp.zeros((), dtype=jnp.float32)
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
118 119
class MultiObjectiveECWorkflowMetric(MetricBase):
    """Workflow metrics for MultiObjectiveECWorkflow.

    Attributes:
        sampled_episodes: The total number of sampled episodes from environments.
        sampled_timesteps_m: The total number of sampled timesteps from environments, measured in millions.
        iterations: The total number of workflow iterations.
    """

    # Zero-initialized scalar counters. They carry no `reduce_fn` metadata, so
    # `all_reduce` leaves them unchanged.
    sampled_episodes: chex.Array = jnp.zeros((), dtype=jnp.uint32)
    sampled_timesteps_m: chex.Array = jnp.zeros((), dtype=jnp.float32)
    iterations: chex.Array = jnp.zeros((), dtype=jnp.uint32)
132 133
class ECTrainMetric(MetricBase):
    """Training metrics for ECWorkflow.

    Attributes:
        objectives: The objective values for the current step.
        ec_metrics: The extra metrics of the training step.
    """

    # Both fields are required (no defaults) and carry no `reduce_fn`
    # metadata, so `all_reduce` leaves them unchanged.
    objectives: chex.Array
    ec_metrics: chex.ArrayTree
144 145
def to_local_dict(obj, *, dict_factory=dict):
    """Convert a dataclass instance to native Python structures recursively.

    Works like `dataclasses.asdict`, but additionally converts JAX arrays to
    NumPy arrays and `PyTreeDict` values to ``dict_factory`` mappings with
    sorted keys.

    Args:
        obj: The dataclass instance to convert.
        dict_factory: Callable that builds a mapping from an iterable of
            ``(key, value)`` pairs; used for dataclasses and `PyTreeDict`.

    Returns:
        The mapping produced by ``dict_factory`` for ``obj``.

    Raises:
        TypeError: If ``obj`` is not a dataclass instance. Dataclass *classes*
            are rejected as well.
    """
    # `dataclasses.is_dataclass` is also true for dataclass types; exclude them
    # like `dataclasses.asdict` does, since the conversion reads instance values.
    if not dataclasses.is_dataclass(obj) or isinstance(obj, type):
        raise TypeError("to_local_dict() should be called on dataclass instances")
    return _to_local_dict_inner(obj, dict_factory)
def _to_local_dict_inner(obj, dict_factory):
    """Recursive worker for `to_local_dict`, adapted from ``dataclasses._asdict_inner``.

    Args:
        obj: The value to convert.
        dict_factory: Mapping constructor applied to dataclass fields and
            `PyTreeDict` items.

    Returns:
        ``obj`` rebuilt recursively: dataclass instances and `PyTreeDict`s become
        ``dict_factory`` mappings, lists/tuples/NamedTuples/dicts keep their
        type, JAX arrays become NumPy arrays, and any other leaf is returned
        unchanged.
    """
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # Dataclass instance; a dataclass *class* used as a value is a leaf.
        return dict_factory(
            [
                (f.name, _to_local_dict_inner(getattr(obj, f.name), dict_factory))
                for f in dataclasses.fields(obj)
            ]
        )
    elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # obj is a namedtuple: rebuild the same type, but pass the converted
        # values positionally because a namedtuple's constructor does not
        # accept a single iterable (see bpo-34363). `_asdict()` is not used
        # since it does not recurse and would turn the namedtuple into a dict.
        return type(obj)(*[_to_local_dict_inner(v, dict_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume the type can be built from a generator (not true for
        # namedtuples, handled above).
        return type(obj)(_to_local_dict_inner(v, dict_factory) for v in obj)
    elif isinstance(obj, PyTreeDict):
        # Checked before `dict`; emitted with sorted keys via dict_factory.
        return dict_factory(
            (
                _to_local_dict_inner(k, dict_factory),
                _to_local_dict_inner(obj[k], dict_factory),
            )
            for k in sorted(obj.keys())
        )
    elif isinstance(obj, dict):
        if hasattr(type(obj), "default_factory"):
            # obj is a defaultdict: its constructor takes default_factory as
            # the first argument, so it cannot be built from a pairs iterable.
            result = type(obj)(obj.default_factory)
            for k, v in obj.items():
                result[_to_local_dict_inner(k, dict_factory)] = _to_local_dict_inner(
                    v, dict_factory
                )
            return result
        return type(obj)(
            (
                _to_local_dict_inner(k, dict_factory),
                _to_local_dict_inner(v, dict_factory),
            )
            for k, v in obj.items()
        )
    elif isinstance(obj, jax.Array):
        return np.array(obj)
    else:
        return obj