Source code for evorl.recorders.wandb_recorder

  1from collections.abc import Mapping
  2from typing import Any
  3import sys
  4
  5import jax.tree_util as jtu
  6import numpy as np
  7import pandas as pd
  8import wandb
  9
 10from .recorder import Recorder
 11
 12
class WandbRecorder(Recorder):
    """Recorder for Weights & Biases.

    Construction only stores the arguments; the wandb run is started by
    `init`, written to by `write`, and finished by `close`.
    """

    def __init__(self, *, project, name, config, tags, path, **wandb_kwargs):
        """Collect the keyword arguments for a later `wandb.init` call.

        Args:
            project: Forwarded to `wandb.init` as ``project``.
            name: Forwarded to `wandb.init` as ``name``.
            config: Forwarded to `wandb.init` as ``config``.
            tags: Forwarded to `wandb.init` as ``tags``.
            path: Forwarded to `wandb.init` as ``dir``.
            **wandb_kwargs: Extra keyword arguments for `wandb.init`. They are
                merged last, so an explicit ``dir=`` here overrides ``path``.
        """
        self.wandb_kwargs = {
            "project": project,
            "name": name,
            "config": config,
            "tags": tags,
            "dir": path,
            **wandb_kwargs,
        }

    def init(self) -> None:
        """Start the wandb run using the arguments stored at construction."""
        wandb.init(**self.wandb_kwargs)

    def write(self, data: Mapping[str, Any], step: int | None = None) -> None:
        """Log one record to wandb.

        Every leaf of ``data`` (as traversed by `jax.tree_util.tree_map`) goes
        through `_convert_data`: `pd.Series` leaves become `wandb.Histogram`,
        `pd.DataFrame` leaves become `wandb.Table`, other leaves are unchanged.

        Args:
            data: Mapping of metric names to values.
            step: Optional step forwarded to `wandb.log`.
        """
        data = jtu.tree_map(_convert_data, data)
        wandb.log(data, step=step)

    def close(self):
        """Finish the wandb run.

        If an exception is currently being handled (for example, `close` is
        called from an ``except`` block), the run is finished with
        ``exit_code=1``. Otherwise `wandb.finish` is called with its defaults.
        """
        exc_type = sys.exc_info()[0]
        if exc_type is not None:
            wandb.finish(exit_code=1)
        else:
            wandb.finish()
39 40 41def _convert_data(val: Any): 42 if isinstance(val, pd.Series): 43 return wandb.Histogram(val) 44 elif isinstance(val, pd.DataFrame): 45 return wandb.Table(dataframe=val) 46 else: 47 return val 48 49
def add_prefix(data: dict, prefix: str):
    """Add prefix to the keys of a dictionary.

    Args:
        data: Mapping whose keys should be namespaced; it is not modified.
        prefix: Namespace placed in front of every key, as ``"<prefix>/<key>"``.

    Returns:
        A new dict holding the same values under the prefixed keys.
    """
    return {
        f"{prefix}/{key}": value
        for key, value in data.items()
    }
53 54
def get_1d_array_statistics(data, histogram=False):
    """Get raw value and statistics of a 1D array.

    Helper function for logging in WandB.

    Args:
        data: 1D array-like of numbers, or None. If data has multiple
            dimensions, it is flattened. NaNs are ignored by the statistics.
        histogram: If True, also return the raw values as a 1D `pd.Series`
            under ``"val"``, which `WandbRecorder` further converts to a
            histogram.

    Returns:
        A dictionary with ``min``, ``max`` and ``mean`` (Python scalars, or
        None when ``data`` is None or empty) and, if ``histogram`` is True,
        ``val``.
    """
    arr = None if data is None else np.asarray(data)

    if arr is None or arr.size == 0:
        # np.nanmin/np.nanmax raise ValueError on zero-size input, so empty
        # data is reported the same way as missing data.
        res = dict(min=None, max=None, mean=None)
        if histogram:
            # Explicit dtype: an empty Series without one warns on pandas 1.x
            # and defaults to object dtype on pandas 2.x.
            res["val"] = pd.Series(dtype=float)
        return res

    res = dict(
        min=np.nanmin(arr).tolist(),
        max=np.nanmax(arr).tolist(),
        mean=np.nanmean(arr).tolist(),
    )

    if histogram:
        # pd.Series only accepts 1-dimensional data, so flatten first.
        res["val"] = pd.Series(arr.ravel())

    return res
83 84
def get_1d_array(data):
    """Get statistics of a 1D array.

    Similar to `get_1d_array_statistics`, but instead of recording a histogram,
    the raw ``data`` is returned unchanged under ``"val"``.

    Args:
        data: Array-like of numbers, or None. NaNs are ignored by the
            statistics; multi-dimensional input is reduced over all elements.

    Returns:
        A dictionary with ``min``, ``max`` and ``mean`` (Python scalars, or
        None when ``data`` is None or empty) and ``val`` (``data`` itself, or
        ``[]`` when ``data`` is None).
    """
    if data is None:
        return dict(min=None, max=None, mean=None, val=[])

    arr = np.asarray(data)
    if arr.size == 0:
        # np.nanmin/np.nanmax raise ValueError on zero-size input.
        return dict(min=None, max=None, mean=None, val=data)

    return dict(
        min=np.nanmin(arr).tolist(),
        max=np.nanmax(arr).tolist(),
        mean=np.nanmean(arr).tolist(),
        val=data,
    )