import copy
import json
import time
from pathlib import Path
import typing as ty
from typing import Any
from typing import Optional, Union
import numpy as np
import pandas as pd
from PIL import Image
import ablator.utils.file as futils
from ablator.config.proto import RunConfig
from ablator.modules.loggers import LoggerBase
from ablator.modules.loggers.file import FileLogger
from ablator.modules.loggers.tensor import TensorboardLogger
from ablator.modules.metrics.main import Metrics
from ablator.modules.metrics.stores import MovingAverage
[docs]class SummaryLogger:
"""
A logger for training and evaluation summary.
Attributes
----------
SUMMARY_DIR_NAME : str
Name of the summary directory.
RESULTS_JSON_NAME : str
Name of the results JSON file.
LOG_FILE_NAME : str
Name of the log file.
CONFIG_FILE_NAME : str
Name of the configuration file.
METADATA_JSON : str
Name of the metadata JSON file.
CHKPT_DIR_NAMES : list[str]
List of checkpoint directory names.
CHKPT_DIR_VALUES : list[str]
List of checkpoint directory values.
CHKPT_DIRS : dict[str, Path]
Dictionary containing checkpoint directories.
keep_n_checkpoints : int
Number of checkpoints to keep.
log_iteration : int
Current log iteration.
checkpoint_iteration : dict[str, dict[str, int]]
``checkpoint_iteration`` is a dictionary that keeps track of the checkpoint iterations for each directory.
It is used in the ``checkpoint()`` method
to determine the appropriate iteration number for the saved checkpoint.
log_file_path : Path | None
Path to the log file.
dashboard : LoggerBase | None
Dashboard logger.
experiment_dir : Path | None
the trial directory.
result_json_path : Path | None
Path to the results JSON file.
Parameters
----------
run_config : RunConfig
The run configuration.
experiment_dir : str | None | Path
Path to the trial directory, by default ``None``.
resume : bool
Whether to resume from an existing model directory, by default ``False``.
keep_n_checkpoints : int | None
Number of checkpoints to keep, by default ``None``.
verbose : bool
Whether to print messages to the console, by default ``True``.
Raises
------
FileExistsError
If resume is set to ``False`` but the experiment directory already exists.
"""
SUMMARY_DIR_NAME: str = "dashboard"
RESULTS_JSON_NAME: str = "results.json"
LOG_FILE_NAME: str = "train.log"
CONFIG_FILE_NAME: str = "config.yaml"
BACKUP_CONFIG_FILE_NAME: str = "config_backup_{i}.yaml"
METADATA_JSON: str = "metadata.json"
CHKPT_DIR_NAMES: list[str] = ["best", "recent"]
CHKPT_DIR_VALUES: list[str] = ["best_checkpoints", "checkpoints"]
CHKPT_DIRS: dict[str, Path]
def __init__(
self,
run_config: RunConfig,
experiment_dir: str | None | Path = None,
resume: bool = False,
keep_n_checkpoints: int | None = None,
verbose: bool = True,
):
# Initialize a SummaryLogger.
run_config = copy.deepcopy(run_config)
self.uid = run_config.uid
self.keep_n_checkpoints: int = (
keep_n_checkpoints if keep_n_checkpoints is not None else int(1e6)
)
self.log_iteration: int = 0
self.checkpoint_iteration: dict[str, dict[str, int]] = {}
self.log_file_path: Path | None = None
self.dashboard: LoggerBase | None = None
self.experiment_dir: Path | None = None
self.result_json_path: Path | None = None
self.CHKPT_DIRS = {}
_log_msg = ""
if experiment_dir is not None:
self.experiment_dir = Path(experiment_dir)
if not resume and self.experiment_dir.exists():
raise FileExistsError(
f"SummaryLogger: Resume is set to {resume} but {self.experiment_dir} exists."
)
if resume and self.experiment_dir.exists():
_run_config = type(run_config).load(
self.experiment_dir.joinpath(self.CONFIG_FILE_NAME)
)
diffs = run_config.diff_str(_run_config)
if len(diffs) > 0:
i = len(
list(
self.experiment_dir.glob(
self.BACKUP_CONFIG_FILE_NAME.format(i="*")
)
)
)
backup_file_name = self.experiment_dir.joinpath(
self.BACKUP_CONFIG_FILE_NAME.format(i=f"{i:03d}")
)
backup_file_name.write_text(_run_config.to_yaml(), encoding="utf-8")
_log_msg += "Differences between provided configuration and "
_log_msg += f"stored configuration. Creating a configuration backup at {backup_file_name}"
metadata = json.loads(
self.experiment_dir.joinpath(self.METADATA_JSON).read_text(
encoding="utf-8"
)
)
self.checkpoint_iteration = metadata["checkpoint_iteration"]
self.log_iteration = metadata["log_iteration"]
(self.summary_dir, *chkpt_dirs) = futils.make_sub_dirs(
experiment_dir, self.SUMMARY_DIR_NAME, *self.CHKPT_DIR_VALUES
)
for name, path in zip(self.CHKPT_DIR_NAMES, chkpt_dirs):
self.CHKPT_DIRS[name] = path
self.result_json_path = self.experiment_dir / self.RESULTS_JSON_NAME
self.log_file_path = self.experiment_dir.joinpath(self.LOG_FILE_NAME)
self.dashboard = self._make_dashboard(self.summary_dir, run_config)
self._write_config(run_config)
self._update_metadata()
self.logger = FileLogger(path=self.log_file_path, verbose=verbose)
if len(_log_msg) > 0:
self.logger.warn(_log_msg)
def _update_metadata(self):
"""
Update the metadata file.
"""
if self.experiment_dir is None:
return
metadata_path = self.experiment_dir.joinpath(self.METADATA_JSON)
metadata_path.write_text(
json.dumps(
{
"log_iteration": self.log_iteration,
"checkpoint_iteration": self.checkpoint_iteration,
}
),
encoding="utf-8",
)
def _make_dashboard(
self, summary_dir: Path, run_config: RunConfig | None = None
) -> LoggerBase | None:
"""
Make a dashboard logger.
Parameters
----------
summary_dir : Path
Path to the summary directory.
run_config : RunConfig | None
The run configuration, by default ``None``.
Returns
-------
LoggerBase | None
A TensorboardLogger.
"""
if run_config is None or not run_config.tensorboard:
return None
return TensorboardLogger(summary_dir.joinpath("tensorboard"))
def _write_config(self, run_config: RunConfig):
"""
Write the run configuration to the model directory and to the dashboard.
Parameters
----------
run_config : RunConfig
The run configuration.
"""
if self.experiment_dir is None:
return
self.experiment_dir.joinpath(self.CONFIG_FILE_NAME).write_text(
run_config.to_yaml(), encoding="utf-8"
)
if self.dashboard is not None:
self.dashboard.write_config(run_config)
# pylint: disable=too-complex
# flake8: noqa: C901
def _add_metric(self, k: str, v: ty.Any, itr: int):
"""
Add a metric to the dashboard.
Parameters
----------
k : str
The metric name.
v : ty.Any
The metric value.
itr : int
The iteration.
Raises
------
ValueError
If the datatype of the value ``v`` is unsupported.
"""
if self.dashboard is None:
return
if isinstance(v, (list, np.ndarray)):
v = np.array(v)
if v.dtype.kind in {"b", "i", "u", "f", "c"}:
v_dict = {str(i): _v for i, _v in enumerate(v)}
self.dashboard.add_scalars(k, v_dict, itr)
else:
self.dashboard.add_text(k, " ".join(v), itr)
elif isinstance(v, dict):
for sub_k, sub_v in v.items():
self.dashboard.add_scalar(f"{k}_{sub_k}", sub_v, itr)
elif isinstance(v, MovingAverage):
self.dashboard.add_scalar(k, v.get(), itr)
elif isinstance(v, str):
self.dashboard.add_text(k, v, itr)
elif isinstance(v, Image.Image):
self.dashboard.add_image(
k, np.array(v).transpose(2, 0, 1), itr, dataformats="CHW"
)
elif isinstance(v, pd.DataFrame):
self.dashboard.add_table(k, v, itr)
elif isinstance(v, (int, float)):
self.dashboard.add_scalar(k, v, itr)
else:
raise ValueError(
f"Unsupported dashboard value {v}. Must be "
"[int,float, pd.DataFrame, Image.Image, str, "
"MovingAverage, dict[str,float|int], list[float,int], np.ndarray] "
)
def _append_metrics(self, metrics: dict[str, float]):
"""Append metrics to the result json file.
Parameters
----------
metrics : dict[str, float]
The metrics to append.
"""
if self.result_json_path is not None:
_metrics = copy.deepcopy(metrics)
_metrics["timestamp"] = float(time.time())
_metrics_str = futils.dict_to_json(_metrics)
if self.result_json_path.exists():
futils.truncate_utf8_chars(self.result_json_path, "]")
with open(self.result_json_path, "a", encoding="utf-8") as fp:
fp.write(",\n" + _metrics_str + "]")
else:
self.result_json_path.write_text(f"[{_metrics_str}]", encoding="utf-8")
[docs] def update(
self,
metrics: Union[Metrics, dict],
itr: Optional[int] = None,
):
"""Update the dashboard with the given metrics.
write some metrics to json files and update the current metadata (``log_iteration``)
Parameters
----------
metrics : Union[Metrics, dict]
The metrics to update.
itr : Optional[int]
The iteration, by default ``None``.
Raises
------
AssertionError
If the iteration is not greater than the current iteration.
Notes
-----
Attribute ``log_iteration`` is increased by 1 every time ``update()`` is called while training models.
"""
if itr is None:
itr = self.log_iteration
self.log_iteration += 1
else:
assert (
itr > self.log_iteration
), f"Current iteration > {itr}. Can not add metrics."
self.log_iteration = itr
if isinstance(metrics, Metrics):
dict_metrics = metrics.to_dict()
else:
dict_metrics = metrics
for k, v in dict_metrics.items():
if v is not None:
self._add_metric(k, v, itr)
self._append_metrics(dict_metrics)
self._update_metadata()
[docs] def checkpoint(
self,
save_dict: object,
file_name: str,
itr: int | None = None,
is_best: bool = False,
):
"""
Save a checkpoint and update the checkpoint iteration
Saves the model checkpoint in the appropriate directory based on the ``is_best`` parameter.
If ``is_best==True``, the checkpoint is saved in the ``"best"`` directory, indicating the best
performing model so far. Otherwise, the checkpoint is saved in the ``"recent"`` directory,
representing the most recent checkpoint.
The file path for the checkpoint is constructed using the selected directory name (``"best"`` or
``"recent"``), and the file name with the format ``"{file_name}_{itr:010}.pt"``, where ``itr`` is the
iteration number.
The ``checkpoint_iteration`` dictionary is updated with the current iteration number for each
directory. If ``itr`` is not provided, the iteration number is increased by 1 each time a
checkpoint is saved. Otherwise, the iteration number is set to the provided ``itr``.
Parameters
----------
save_dict : object
The object to save.
file_name : str
The file name.
itr : int | None
The iteration. If not provided, the current iteration is incremented by 1, by default ``None``.
is_best : bool
Whether this is the best checkpoint, by default ``False``.
"""
if self.experiment_dir is None:
return
dir_name = "best" if is_best else "recent"
if self.keep_n_checkpoints > 0:
if dir_name not in self.checkpoint_iteration:
self.checkpoint_iteration[dir_name] = {}
if file_name not in self.checkpoint_iteration[dir_name]:
self.checkpoint_iteration[dir_name][file_name] = -1
if itr is None:
self.checkpoint_iteration[dir_name][file_name] += 1
itr = self.checkpoint_iteration[dir_name][file_name]
else:
cur_iter = self.checkpoint_iteration[dir_name][file_name]
assert (
itr > cur_iter
), f"Checkpoint iteration {cur_iter} >= training iteration {itr}. Can not overwrite checkpoint."
self.checkpoint_iteration[dir_name][file_name] = itr
dir_path = self.experiment_dir.joinpath(self.CHKPT_DIRS[dir_name])
file_path = dir_path.joinpath(f"{file_name}_{itr:010}.pt")
assert not file_path.exists(), f"Checkpoint exists: {file_path}"
futils.save_checkpoint(save_dict, file_path.as_posix())
futils.clean_checkpoints(dir_path, self.keep_n_checkpoints)
self._update_metadata()
[docs] def clean_checkpoints(self, keep_n_checkpoints: int):
"""
Clean up checkpoints and keep only the specified number of checkpoints.
Parameters
----------
keep_n_checkpoints : int
Number of checkpoints to keep.
"""
if self.experiment_dir is None:
return
for chkpt_dir in self.CHKPT_DIR_VALUES:
dir_path = self.experiment_dir.joinpath(chkpt_dir)
futils.clean_checkpoints(dir_path, keep_n_checkpoints)
[docs] def info(self, *args: Any, **kwargs: Any):
"""
Log an info to files and to console message using the logger. Here you can use
positional or keyword arguments. Possible parameters are shown in the Parameters section.
Parameters
----------
msg : str
The message to log,
verbose : bool
Whether to print messages to the console, by default ``False``.
"""
self.logger.info(*args, **kwargs)
[docs] def warn(self, *args: Any, **kwargs: Any):
"""
Log a warning message to files and to console using the logger. Here you can use
positional or keyword arguments. Possible parameters are shown in the Parameters section.
Parameters
----------
msg : str
The message to log,
verbose : bool
Whether to print messages to the console, by default ``True``.
"""
self.logger.warn(*args, **kwargs)
[docs] def error(self, *args: Any, **kwargs: Any):
"""
Log an error message to files and to console using the logger. Here you can use
positional or keyword arguments. Possible parameters are shown in the Parameters section.
Parameters
----------
msg : str
The message to log.
"""
self.logger.error(*args, **kwargs)