Source code for evorl.recorders.log_recorder

 1import logging
 2from collections.abc import Mapping
 3from typing import Any
 4
 5import jax.tree_util as jtu
 6import numpy as np
 7import pandas as pd
 8
 9# from pprint import pformat
10import yaml
11
12from .recorder import Recorder
13
14# class SubLoggerFilter(logging.Filter):
15#     def filter(self, record):
16#         # Only allow log records that have the sub-logger's name
17#         return record.name == self.name
18
19
class LogRecorder(Recorder):
    """Log file recorder.

    Writes each metrics record as a YAML block to a dedicated log file through
    the process-wide ``"LogRecorder"`` logger, and optionally lets the record
    propagate to the root logger's handlers (e.g. the console).
    """

    def __init__(self, log_path: str, console: bool = True):
        """Initialize the log recorder.

        No file is opened here; call :meth:`init` before :meth:`write`.

        Args:
            log_path: The path to the log file. It is truncated by ``init``.
            console: Whether to print the log to the console. Defaults to True.
        """
        self.log_path = log_path
        self.console = console

    def init(self) -> None:
        """Open the log file and attach it to the ``"LogRecorder"`` logger."""
        self.logger = logging.getLogger("LogRecorder")

        # Calling init() again must not leak the previous file handler or
        # leave it attached to the shared logger (which would duplicate
        # every record).
        old_handler = getattr(self, "file_handler", None)
        if old_handler is not None:
            self.logger.removeHandler(old_handler)
            old_handler.close()

        self.file_handler = logging.FileHandler(self.log_path, mode="w")
        # Reuse the root logger's formatter (usually set by hydra) when the
        # root logger has a handler; otherwise keep logging's default
        # formatter instead of failing with IndexError.
        root_handlers = logging.getLogger().handlers
        if root_handlers:
            self.file_handler.setFormatter(root_handlers[0].formatter)
        self.logger.addHandler(self.file_handler)

        # Propagation to the root logger's handlers is what prints records to
        # the console. Set it explicitly in both directions because the named
        # logger is shared process-wide and an earlier recorder may have
        # turned it off.
        # NOTE: this method sets no level on the logger, so whether INFO
        # records are emitted depends on the effective level it inherits.
        self.logger.propagate = self.console

    def write(self, data: Mapping[str, Any], step: int | None = None) -> None:
        """Log one record of metrics.

        Args:
            data: A (possibly nested) mapping of metrics. Every leaf is passed
                through ``_convert_data`` before the mapping is dumped as YAML.
            step: The iteration number shown in the record header.
        """
        data = jtu.tree_map(_convert_data, data)
        formatted_data = f"iteration {step}:\n" + yaml.dump(data, indent=2)
        self.logger.info(formatted_data)

    def close(self) -> None:
        """Detach and close the log file handler."""
        # Detach first so later records on the shared logger are not routed
        # to a closed handler.
        self.logger.removeHandler(self.file_handler)
        self.file_handler.close()
51 52 53def _convert_data(val): 54 if isinstance(val, np.ndarray): 55 return val.tolist() 56 elif isinstance(val, np.generic): 57 return val.item() 58 elif isinstance(val, pd.Series) or isinstance(val, pd.DataFrame): 59 # escape the special data for wandb 60 return None 61 else: 62 return val