Source code for evorl.recorders.log_recorder
1import logging
2from collections.abc import Mapping
3from typing import Any
4
5import jax.tree_util as jtu
6import numpy as np
7import pandas as pd
8
9# from pprint import pformat
10import yaml
11
12from .recorder import Recorder
13
14# class SubLoggerFilter(logging.Filter):
15# def filter(self, record):
16# # Only allow log records that have the sub-logger's name
17# return record.name == self.name
18
19
[docs]
class LogRecorder(Recorder):
    """Log file recorder.

    Each record passed to :meth:`write` is emitted as a YAML-formatted INFO
    message on the ``"LogRecorder"`` logger. :meth:`init` attaches a file
    handler for ``log_path`` to that logger. Console output, when enabled,
    comes from propagating records to the root logger's handlers.
    """

    def __init__(self, log_path: str, console: bool = True):
        """Initialize the log recorder.

        Args:
            log_path: The path to the log file.
            console: Whether to print the log to the console. Defaults to True.
        """
        self.log_path = log_path
        self.console = console

    def init(self) -> None:
        """Attach a file handler writing to ``log_path``.

        The file is opened with ``mode="w"``, so any existing file at
        ``log_path`` is truncated.
        """
        # NOTE: logging.getLogger returns the same logger object for the same
        # name, so every LogRecorder instance in the process shares it.
        self.logger = logging.getLogger("LogRecorder")

        self.file_handler = logging.FileHandler(self.log_path, mode="w")
        # Reuse the root logger's formatter (usually set by hydra) when a root
        # handler exists. Without one, keep logging's default format instead
        # of crashing with an IndexError.
        root_handlers = logging.getLogger().handlers
        if root_handlers:
            self.file_handler.setFormatter(root_handlers[0].formatter)
        self.logger.addHandler(self.file_handler)

        # Set propagation explicitly in both directions. The logger is shared,
        # so a previous console=False recorder must not silence this one.
        self.logger.propagate = self.console

    def write(self, data: Mapping[str, Any], step: int | None = None) -> None:
        """Log ``data`` as a YAML block under an ``iteration {step}:`` header.

        Args:
            data: Mapping (pytree) of values to log. Every leaf is first passed
                through ``_convert_data``: NumPy arrays/scalars become plain
                Python values, and pandas Series/DataFrame leaves become ``None``.
            step: Iteration number shown in the header line.
        """
        data = jtu.tree_map(_convert_data, data)
        formatted_data = f"iteration {step}:\n" + yaml.dump(data, indent=2)
        self.logger.info(formatted_data)

    def close(self) -> None:
        """Detach the file handler created by :meth:`init`, then close it."""
        # Detach first, so later messages on the shared logger (e.g. from
        # another LogRecorder) are not routed to a closed file handler.
        self.logger.removeHandler(self.file_handler)
        self.file_handler.close()
51
52
53def _convert_data(val):
54 if isinstance(val, np.ndarray):
55 return val.tolist()
56 elif isinstance(val, np.generic):
57 return val.item()
58 elif isinstance(val, pd.Series) or isinstance(val, pd.DataFrame):
59 # escape the special data for wandb
60 return None
61 else:
62 return val