Source code for evorl.recorders.wandb_recorder
1from collections.abc import Mapping
2from typing import Any
3import sys
4
5import jax.tree_util as jtu
6import numpy as np
7import pandas as pd
8import wandb
9
10from .recorder import Recorder
11
12
[docs]
class WandbRecorder(Recorder):
    """Recorder that sends logged data to Weights & Biases.

    Constructing the recorder only stores the arguments for `wandb.init`;
    the run itself is started by `init` and ended by `close`.
    """

    def __init__(self, *, project, name, config, tags, path, **wandb_kwargs):
        """Collect the keyword arguments later handed to `wandb.init`.

        Args:
            project: Forwarded to `wandb.init` as `project`.
            name: Forwarded to `wandb.init` as `name`.
            config: Forwarded to `wandb.init` as `config`.
            tags: Forwarded to `wandb.init` as `tags`.
            path: Forwarded to `wandb.init` as `dir`.
            **wandb_kwargs: Extra keyword arguments for `wandb.init`. On a key
                clash (e.g. `dir`) these win over the values above.
        """
        init_kwargs = {
            "project": project,
            "name": name,
            "config": config,
            "tags": tags,
            "dir": path,
        }
        # Applied last so that user-supplied extras override the defaults.
        init_kwargs.update(wandb_kwargs)
        self.wandb_kwargs = init_kwargs

    def init(self) -> None:
        """Start the W&B run with the stored keyword arguments."""
        wandb.init(**self.wandb_kwargs)

    def write(self, data: Mapping[str, Any], step: int | None = None) -> None:
        """Log `data` to W&B at the given `step`.

        Every leaf of `data` is passed through `_convert_data` before logging.
        """
        converted = jtu.tree_map(_convert_data, data)
        wandb.log(converted, step=step)

    def close(self):
        """Finish the W&B run.

        If an exception is currently being handled, the run is finished with
        `exit_code=1`; otherwise it is finished with wandb's default.
        """
        active_exc_type = sys.exc_info()[0]
        if active_exc_type is None:
            wandb.finish()
            return
        wandb.finish(exit_code=1)
39
40
41def _convert_data(val: Any):
42 if isinstance(val, pd.Series):
43 return wandb.Histogram(val)
44 elif isinstance(val, pd.DataFrame):
45 return wandb.Table(dataframe=val)
46 else:
47 return val
48
49
[docs]
def add_prefix(data: dict, prefix: str):
    """Return a new dictionary whose keys are `"<prefix>/<key>"`.

    Args:
        data: Dictionary whose keys should be prefixed. It is not modified.
        prefix: Text placed before each key, separated from it by `/`.

    Returns:
        A new dictionary with the same values under the prefixed keys.
    """
    prefixed = {}
    for key, value in data.items():
        prefixed[f"{prefix}/{key}"] = value
    return prefixed
53
54
[docs]
def get_1d_array_statistics(data, histogram=False):
    """Get raw value and statistics of a 1D array.

    Helper function for logging in WandB.

    Args:
        data: 1D numpy array, or None. If data has multiple dimensions, it will be viewed as flattened.
        histogram: If True, return raw data in `pd.Series`, which will be further converted to histogram in `WandbRecorder`.

    Returns:
        A dictionary containing min, max, mean, and optional raw data under `val`.
        When data is None or has no elements, min, max and mean are None and
        `val` (if requested) is an empty `pd.Series`.
    """
    # np.nanmin/np.nanmax raise ValueError on zero-size input, so an empty
    # array is reported the same way as missing data.
    if data is None or np.size(data) == 0:
        res = dict(min=None, max=None, mean=None)
        if histogram:
            res["val"] = pd.Series()
        return res

    # The NaN-aware reductions already operate over all elements, so
    # multi-dimensional input needs no reshaping here.
    res = dict(
        min=np.nanmin(data).tolist(),
        max=np.nanmax(data).tolist(),
        mean=np.nanmean(data).tolist(),
    )

    if histogram:
        # pd.Series only accepts 1-D data; flatten anything with more
        # dimensions so the documented "viewed as flattened" contract holds.
        # Inputs that are already at most 1-D are passed through untouched.
        flat = np.ravel(data) if np.ndim(data) > 1 else data
        res["val"] = pd.Series(flat)

    return res
83
84
[docs]
def get_1d_array(data):
    """Get raw value and statistics of a 1D array.

    Similar to `get_1d_array_statistics`, but instead of wrapping the values
    for a histogram, the raw data is returned unchanged under `val`.

    Args:
        data: Array-like values, or None.

    Returns:
        A dictionary containing min, max, mean (computed ignoring NaN and
        converted to Python scalars) and `val`, the input data itself.
        When data is None or has no elements, min, max and mean are None and
        `val` is an empty list.
    """
    # np.nanmin/np.nanmax raise ValueError on zero-size input, so an empty
    # array is reported the same way as missing data.
    if data is None or np.size(data) == 0:
        return dict(min=None, max=None, mean=None, val=[])

    return dict(
        min=np.nanmin(data).tolist(),
        max=np.nanmax(data).tolist(),
        mean=np.nanmean(data).tolist(),
        val=data,
    )