# Source code for ablator.main.model.main

from contextlib import nullcontext
import copy
import math
import traceback
import typing as ty
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Callable
from functools import cached_property
from pathlib import Path

import setproctitle
import torch
from torch import nn
from torch.utils.data import DataLoader

import ablator.utils.base as butils
from ablator.config.proto import Optim, RunConfig
from ablator.modules.loggers.main import SummaryLogger
from ablator.modules.metrics.main import Metrics
from ablator.utils.base import Dummy, Lock
from ablator.utils.progress_bar import ProgressBar, RemoteProgressBar


class EvaluationError(Exception):
    """Exception type for failures while evaluating a model.

    Not raised by ``ModelBase`` itself; presumably raised by subclasses or the
    training loop — confirm against callers.
    """
class TrainPlateauError(Exception):
    """Exception type signalling that training has plateaued.

    Not raised by ``ModelBase`` itself; presumably raised by the training loop
    (e.g. early stopping) — confirm against callers.
    """
class LogStepError(Exception):
    """Exception type for errors raised while performing a logging step.

    Not raised by ``ModelBase`` itself; presumably raised by the training loop —
    confirm against callers.
    """
class CheckpointNotFoundError(FileNotFoundError):
    """Raised by ``ModelBase._find_load_valid_checkpoint`` when the checkpoint
    directory contains no checkpoint to load.

    Subclasses ``FileNotFoundError`` so callers handling missing files also catch it.
    """
# pylint: disable=too-many-instance-attributes
class ModelBase(ABC):
    """
    Base class that removes training boiler-plate code with extensible support for multiple
    use-cases. The class follows a stateful initialization paradigm. Requires the user to implement
    specific to their use-case load model and creation functionality.

    Attributes
    ----------
    model_class : Type[nn.Module]
        The class definition of the model's structure, which is a subclass of ``nn.Module``.
    run_config : RunConfig
        An instance of ``RunConfig`` containing configuration details.
    train_dataloader : DataLoader
        A DataLoader object responsible for model training.
    val_dataloader : Optional[DataLoader]
        An optional DataLoader object used for model evaluation.
    test_dataloader : Optional[DataLoader]
        An optional DataLoader object used for model testing.
    logger : Union[SummaryLogger, Dummy]
        Records information on the program's operation and model training, such as progress and
        performance metrics. A ``Dummy`` logger is used when running a smoke test.
    device : str
        The type of device used for running the experiment. i.e. ``"cuda"``, ``"cpu"``,
        ``"cuda:0"``.
    experiment_dir : Path | None
        The experiment directory, or ``None`` when ``run_config.experiment_dir`` is unset.
    verbose : {"progress", "console", "silent"}
        The verbosity mode read from ``run_config.verbose``. ``"progress"`` displays a progress
        bar, ``"console"`` makes the ``SummaryLogger`` verbose and ``"silent"`` ignores all
        Python warnings.
    amp : bool
        If ``True``, apply automatic mixed precision training, otherwise default precision.
        It is always disabled when running on CPU.
    random_seed : Optional[int]
        Sets the seed for generating random numbers.
    progress_bar : Union[ProgressBar, Dummy]
        An optional instance of ``ProgressBar`` that displays real-time information during
        training. e.g. time remaining. Only applied for the master process.
    current_checkpoint : Optional[Path]
        Path of the checkpoint the current state was loaded from, ``None`` when the model was
        freshly created.
    train_metrics : Metrics
        Training metrics including model information. i.e. learning rate and loss value.
    eval_metrics : Metrics | None
        Evaluation metrics for when a ``val_dataloader`` is provided.
    current_state : dict
        The current state of the model, including run_config, metrics and other necessary states.
    learning_rate : float
        The current learning rate.
    total_steps : int
        The total steps for the training process.
    epochs : int
        The total epochs for the training process.
    current_iteration : int
        The current iteration of training.
    best_iteration : int | None
        The iteration with the best optimization metric value.
    best_metrics : dict[str, float]
        The best values of the optimization metric encountered during training, according to
        ``optim_metric_direction`` (initialized to ``inf`` for ``Optim.min`` and ``-inf``
        otherwise).
    optim_metric_name : str | None
        The name of the optimization metric.
    optim_metric_direction : Optim | None
        The optimization direction.

    Parameters
    ----------
    model_class : type[nn.Module]
        The base class for user's model, which defines the neural network.

    Notes
    -----
    1. Class properties are simply listed by name. Please check out property docstring for more
       information.
    2. Users must implement the abstract methods to customize the model's behavior.
    3. Mixed precision training enables some operations to use the ``torch.float32`` datatype and
       other operations use lower precision floating point datatype ``torch.float16``. This is for
       saving time and reducing memory usage. Ordinarily, "automatic mixed precision training"
       means training with ``torch.autocast`` and ``torch.cuda.amp.GradScaler`` together.
       More information: https://pytorch.org/docs/stable/amp.html
    """

    def __init__(
        self,
        model_class: type[nn.Module],
    ):
        self.model_class = model_class
        self.run_config: RunConfig
        self.train_dataloader: DataLoader
        self.val_dataloader: DataLoader | None
        self.test_dataloader: DataLoader | None
        self.logger: ty.Union[SummaryLogger, Dummy]
        self.device: str
        self.experiment_dir: Path | None = None
        self.verbose: ty.Literal["progress", "console", "silent"]
        self.amp: bool
        self.random_seed: ty.Optional[int]
        self.progress_bar: ProgressBar | butils.Dummy
        self.optim_metric_name: str | None
        self.optim_metric_direction: Optim | None
        self.current_checkpoint: Path | None
        # Runtime metrics
        self.train_metrics: Metrics
        self.eval_metrics: Metrics | None
        self.current_state: dict
        # stats
        self.learning_rate: float
        self.total_steps: int
        self.epochs: int
        # self.current_epoch: int
        self.current_iteration: int
        self.best_iteration: int | None
        self.best_metrics: dict[str, float]
        # Attributes updated during training
        self._running_stats_names: list[str] = [
            "best_iteration",
            "best_metrics",
            "current_iteration",
            "learning_rate",
        ]
        # Attributes derived from configuration
        self._derived_stats_names: list[str] = [
            "epoch_len",
            "current_epoch",
            "log_itr",
            "eval_itr",
            "uid",
            "total_steps",
            "epochs",
        ]
        self._cached_properties: list[str] = [
            "epoch_len",
            "eval_itr",
            "log_itr",
        ]
        self._overridable_stats_names: list[str] = []
        # internal properties
        self._uid: str
        self._autocast: torch.autocast
        self._is_init = False
        # functions that call init_state
        self._init_function_names = ["init_state"]

    def _init_attrs(self):
        """Resets the runtime attributes to their initial values before (re-)initialization."""
        self.current_iteration = 0
        self.best_iteration = None
        self.best_metrics = {}
        self.learning_rate = float("inf")
        self.current_state = {}
        self.eval_metrics = None
        self.current_checkpoint = None
        self.val_dataloader = None
        self.test_dataloader = None

    def _reset_cached_attributes(self):
        """
        Drops the values memoized by the ``cached_property`` attributes listed in
        ``_cached_properties`` so that they are recomputed on their next access.

        ``functools.cached_property`` stores its value in the instance ``__dict__``, so the
        cached values are removed from there directly. Probing them with ``hasattr`` would
        evaluate a property that is not cached yet. ``epoch_len`` raises ``RuntimeError`` (which
        ``hasattr`` does not swallow) when ``train_dataloader`` is undefined. That happens, for
        example, when ``init_state`` is retried after ``make_dataloaders`` failed.
        """
        instance_dict = self.__dict__
        for p in self._cached_properties:
            instance_dict.pop(p, None)

    @property
    def train_stats(self) -> OrderedDict:
        """
        Returns an ordered dictionary containing the current training statistics.

        Returns
        -------
        OrderedDict
            An ordered dictionary with the following keys and values:

            - learning_rate: The current learning rate.
            - total_steps: The total steps for the training process.
            - epochs: The number of epochs for training.
            - current_epoch: The current epoch during training.
            - current_iteration: The current iteration during training.
            - best_iteration: The iteration with the best optim metric value so far.
            - best_*: The best optim metric value (according to ``optim_metric_direction``)
              achieved during training, for example `best_val_loss`.
        """
        train_stats = OrderedDict(
            learning_rate=self.learning_rate,
            total_steps=self.total_steps,
            epochs=self.epochs,
            current_epoch=self.current_epoch,
            current_iteration=self.current_iteration,
            best_iteration=self.best_iteration,
        )
        for k, v in self.best_metrics.items():
            train_stats[f"best_{k}"] = v
        return train_stats

    @property
    def current_epoch(self) -> int:
        """
        Calculates and returns the current epoch during training.

        Returns
        -------
        int
            The current epoch number, i.e.
            ``floor(current_iteration * epochs / total_steps)``, or 0 before training starts.
        """
        if self.current_iteration > 0:
            # Integer floor-division is exact, whereas ``floor(i / t * e)`` can round down on
            # an epoch boundary, e.g. ``1 / 49 * 49 == 0.9999999999999999``.
            return int(self.current_iteration * self.epochs // self.total_steps)
        return 0

    @cached_property
    def epoch_len(self) -> int:
        """
        Returns the length of an epoch, which is the number of batches in the
        ``train_dataloader``.

        Returns
        -------
        int
            The length of an epoch, represented as the number of batches in the
            ``train_dataloader``.

        Raises
        ------
        RuntimeError
            If the ``train_dataloader`` is not defined or its length is 0.
        """
        if not hasattr(self, "train_dataloader") or len(self.train_dataloader) == 0:
            raise RuntimeError("Undefined train_dataloader.")
        return len(self.train_dataloader)

    @cached_property
    def eval_itr(self) -> int:
        """
        Calculate the interval between evaluations.

        Returns
        -------
        int
            The interval between evaluations, ``ceil(run_config.eval_epoch * epoch_len)``.
        """
        return math.ceil(self.run_config.eval_epoch * self.epoch_len)

    @cached_property
    def log_itr(self) -> int:
        """
        Calculate the interval between logging steps.

        Returns
        -------
        int
            The interval between logging steps, ``ceil(run_config.log_epoch * epoch_len)``.
        """
        return math.ceil(self.run_config.log_epoch * self.epoch_len)

    @property
    def uid(self) -> str:
        """
        Returns a unique identifier (UID) for the current run configuration.

        Returns
        -------
        str
            ``self._uid`` when it has been set, otherwise ``run_config.uid``.
        """
        return getattr(self, "_uid", self.run_config.uid)

    @abstractmethod
    def create_model(
        self,
        save_dict: dict[str, ty.Any] | None = None,
        strict_load: bool = True,
    ) -> None:
        """
        Abstract method to create and initialize the model. Must be implemented by subclasses.
        Example implementation: Please see the ``create_model`` method in the ``ModelWrapper``
        class.

        Parameters
        ----------
        save_dict : dict[str, ty.Any] | None
            A dictionary containing saved model data, such as weights, optimizer state, etc.,
            to be loaded into the model, by default ``None``.
        strict_load : bool
            If True, the model will be loaded strictly, ensuring that the saved state matches
            the model's structure exactly. If False, the model can be loaded with a partially
            matching state, by default ``True``.

        Raises
        ------
        NotImplementedError
            If this method is not implemented by the subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def checkpoint(self, is_best: bool = False):
        """
        Abstract method to save a checkpoint of the model. Must be implemented by subclasses.
        Example implementation: Please see the ``checkpoint`` method in the ``ModelWrapper``
        class.

        Parameters
        ----------
        is_best : bool
            Indicates if the current checkpoint is the best model so far, by default ``False``.

        Raises
        ------
        NotImplementedError
            If this method is not implemented by the subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def train(
        self,
        run_config: RunConfig,
        smoke_test: bool = False,
    ):
        """
        Abstract method to train the model. Must be implemented by subclasses.
        Example implementation: Please see the ``train`` method in the ``ModelWrapper`` class.

        Parameters
        ----------
        run_config : RunConfig
            An instance of ``RunConfig`` containing configuration details.
        smoke_test : bool
            Whether to run as a smoke test, by default ``False``.

        Raises
        ------
        NotImplementedError
            If this method is not implemented by the subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate(
        self,
        run_config: RunConfig,
    ):
        """
        Abstract method to evaluate the model. Must be implemented by subclasses.
        Example implementation: Please see the ``evaluate`` method in the ``ModelWrapper`` class.

        Parameters
        ----------
        run_config : RunConfig
            An instance of ``RunConfig`` containing configuration details.

        Raises
        ------
        NotImplementedError
            If this method is not implemented by the subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def make_dataloaders(self, run_config: RunConfig):
        """
        Abstract method to create dataloaders for the training, validation, and testing datasets.

        This method should define the process of loading the data and creating dataloaders for
        the training, validation, and testing datasets based on the provided ``run_config``.

        Must be implemented by subclasses.
        Example implementation: Please see the ``make_dataloaders`` method in the
        ``ModelWrapper`` class.

        Parameters
        ----------
        run_config : RunConfig
            An instance of ``RunConfig`` containing configuration details.

        Raises
        ------
        NotImplementedError
            If this method is not implemented by the subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def config_parser(self, run_config: RunConfig):
        """
        Abstract method to parse the provided configuration. Must be implemented by subclasses.
        Example implementation: Please see the ``config_parser`` method in the ``ModelWrapper``
        class.

        Parameters
        ----------
        run_config : RunConfig
            An instance of ``RunConfig`` containing configuration details.

        Raises
        ------
        NotImplementedError
            If this method is not implemented by the subclasses.
        """
        raise NotImplementedError

    def _config_parser(self, run_config: RunConfig) -> RunConfig:
        """
        Internal method to process a run_config. It is called internally and wraps
        `config_parser`.

        Parameters
        ----------
        run_config : RunConfig
            An instance of ``RunConfig`` containing configuration details.

        Returns
        -------
        RunConfig
            The processed configuration.
        """
        return self.config_parser(run_config)

    def _init_logger(self, resume: bool = False, debug: bool = False):
        """
        Initializes the logger used for recording experiment details and progress.

        Parameters
        ----------
        resume : bool
            If True, the logger will resume logging from a previous experiment,
            by default ``False``.
        debug : bool
            If True, no artifacts will be saved by the ``SummaryLogger``, by default ``False``.
        """
        self.logger = SummaryLogger(
            run_config=self.run_config,
            # No experiment directory means the logger writes no artifacts.
            experiment_dir=self.experiment_dir if not debug else None,
            resume=resume,
            keep_n_checkpoints=self.run_config.keep_n_checkpoints,
            verbose=self.run_config.verbose == "console",
        )
        if butils.debugger_is_active() and not debug:
            self.logger.warn("Debug flag is False but running debugger.")
        elif debug:
            self.logger.warn("Debug flag is True, will not save any checkpoints.")

        if self.experiment_dir is not None:
            self.logger.info(f"Model directory: {self.experiment_dir}")

    def _make_dataloaders(
        self, run_config: RunConfig, data_lock: ty.Optional[Lock] = None
    ):
        """
        Creates the data loaders for the training process.

        Parameters
        ----------
        run_config : RunConfig
            An instance of ``RunConfig`` containing configuration details.
        data_lock: ty.Optional[Lock]
            A lock for multiprocessing context that prevents simultaneous processing and
            downloading of the dataset, by default ``None``.
        """
        context_lock: ty.Union[nullcontext, Lock]
        if data_lock is None:
            context_lock = nullcontext()
        else:
            context_lock = data_lock
        with context_lock:
            self.make_dataloaders(run_config)
        assert (
            len(self.train_dataloader) > 0
        ), "Must define a train dataloader in `make_dataloader`"

    def _parse_optim_metrics(
        self, run_config: RunConfig
    ) -> tuple[Optim, str] | tuple[None, None]:
        """
        Parses the optimization metrics and their direction to validate they meet several
        training constraints. For example, the scheduler optimization mode should be aligned
        with the configuration optimization mode. Other configurations such as EarlyStopping
        also depend on the optimization metrics.

        Parameters
        ----------
        run_config : RunConfig
            The configuration to parse

        Returns
        -------
        tuple[Optim, str] | tuple[None, None]
            The optimization direction and metric name, or a tuple of ``None`` if both
            ``optim_metrics`` and ``optim_metric_name`` are unspecified.

        Raises
        ------
        ValueError
            When only one of ``optim_metrics`` / ``optim_metric_name`` is specified, when
            ``optim_metric_name`` is not a key of ``optim_metrics``, or when they are
            unspecified but a scheduler with a ``mode`` argument or ``early_stopping_iter``
            requires them.
        """
        scheduler_config = run_config.train_config.scheduler_config
        optim_metric_name = run_config.optim_metric_name
        optim_metrics = run_config.optim_metrics

        if (optim_metrics is None) != (optim_metric_name is None):
            raise ValueError(
                "Invalid configuration. Must specify both `optim_metrics` and `optim_metric_name` or neither."
            )

        scheduler_requires_metric = (
            scheduler_config is not None
            and hasattr(scheduler_config, "arguments")
            and hasattr(scheduler_config.arguments, "mode")
        )

        if optim_metrics is None or optim_metric_name is None:
            # Both are unspecified; only valid when nothing depends on an optim metric.
            if scheduler_requires_metric:
                raise ValueError(
                    f"Must provide `optim_metrics` when using Scheduler = `{getattr(scheduler_config,'name', 'N/A')}`."
                )
            if run_config.early_stopping_iter is not None:
                raise ValueError(
                    f"Must provide `optim_metrics` when using early_stopping_iter = `{run_config.early_stopping_iter}`."
                )
            return None, None

        # Only stringify once the name is known to be set; ``str(None)`` would yield "None".
        optim_metric_name = str(optim_metric_name)
        if optim_metric_name not in optim_metrics:
            raise ValueError(
                f"optim_metric_name={optim_metric_name} "
                f"was not found in optim_metrics={optim_metrics}"
            )
        optim_direction = optim_metrics[optim_metric_name]

        if scheduler_requires_metric:
            mode = scheduler_config.arguments.mode  # type: ignore[union-attr]
            if (direction := optim_direction.value) != mode:
                # NOTE(review): this method only warns; it does not modify the (already frozen)
                # run config. Presumably the override happens where the scheduler is
                # created — confirm.
                self.logger.warn(
                    f"Different optim_metric_direction {direction} than "
                    f"scheduler.arguments.mode {mode}. Overwriting scheduler.arguments.mode."
                )
        return optim_direction, optim_metric_name

    def _init_class_attributes(self):
        """
        Initializes the class attributes based on the provided configuration.

        This function sets up the device, the optimization metrics and their best values,
        mixed precision (disabled on CPU), warnings handling, early-stopping validation, the
        train/eval ``Metrics`` objects and the process title.
        """
        run_config = self.run_config
        self.device = butils.parse_device(run_config.device)
        self.optim_metric_direction, self.optim_metric_name = self._parse_optim_metrics(
            run_config
        )
        if self.optim_metric_direction is not None:
            # Start from the worst possible value for the optimization direction.
            self.best_metrics = {
                self.optim_metric_name: (
                    float("inf")
                    if self.optim_metric_direction == Optim.min
                    else float("-inf")
                )
            }
        else:
            self.best_metrics = {}

        self.amp = run_config.amp
        if self.device == "cpu" and self.amp:
            self.logger.warn(
                "Automatic Mixed Precision (AMP) is not supported for CPU. Setting `amp` to False."
            )
            self.amp = False
        if (batch_lim := run_config.metrics_n_batches) > len(
            self.train_dataloader
        ) * 0.2:
            self.logger.warn(
                f"Metrics batch-limit {batch_lim} is larger than "
                f"20% of the train dataloader length {len(self.train_dataloader)}. "
                "You might experience slow-down during training. Consider decreasing `metrics_n_batches`."
            )

        self._autocast = torch.autocast(
            enabled=self.amp,
            device_type="cuda" if "cuda" in self.device else "cpu",
        )

        self.verbose = run_config.verbose
        if self.verbose == "silent":
            warnings.filterwarnings("ignore")
        if (
            run_config.early_stopping_iter is not None
            and run_config.early_stopping_iter > 0
        ):
            assert (
                self.val_dataloader is not None
            ), "dataloader function has to return validation set when setting early stopping to True"

        self.train_metrics = Metrics(
            batch_limit=run_config.metrics_n_batches,
            memory_limit=int(run_config.metrics_mb_limit * 1e6),
            moving_average_limit=self.epoch_len,
            evaluation_functions=self.evaluation_functions(),
            static_aux_metrics=self.train_stats,
            moving_aux_metrics=["loss"],
        )

        if self.val_dataloader is not None:
            self.eval_metrics = Metrics(
                batch_limit=None,
                memory_limit=int(run_config.metrics_mb_limit * 1e6),
                moving_average_limit=None,
                evaluation_functions=self.evaluation_functions(),
                moving_aux_metrics=["loss"],
            )
        setproctitle.setproctitle(self.uid)

    def _init_model_state(
        self,
        resume: bool = False,
        smoke_test: bool = False,
        from_chkpt: str | Path | None = None,
    ):
        """
        Initializes the model state based on provided parameters and configuration.

        The state is taken, in order of precedence, from ``from_chkpt`` (full state), from
        ``run_config.init_chkpt`` when not resuming (weights only), from the most recent
        checkpoint when resuming outside a smoke test, or created from scratch.

        Parameters
        ----------
        resume : bool
            If True, tries to resume training from a checkpoint, by default ``False``.
        smoke_test : bool
            Whether to run as a smoke test, by default ``False``.
        from_chkpt: str | Path | None, optional
            Path to the checkpoint to initialize the state from.

        Raises
        ------
        RuntimeError
            If the directory containing checkpoints is not found.
        """
        if from_chkpt is not None:
            self.current_checkpoint = Path(from_chkpt)
            self._load_model(self.current_checkpoint, model_only=False)
        elif self.run_config.init_chkpt is not None and not resume:
            # Loads only the weights
            self.current_checkpoint = Path(self.run_config.init_chkpt)
            self.logger.info(
                f"Initializing model weights ONLY from checkpoint. {self.current_checkpoint}"
            )
            self._load_model(self.current_checkpoint, model_only=True)
        elif resume and not smoke_test:
            if "recent" not in self.logger.CHKPT_DIRS:
                raise RuntimeError("Checkpoint folder was not found.")
            recent_checkpoint_dir = self.logger.CHKPT_DIRS["recent"]
            # NOTE: current_checkpoint is found in _find_load_valid_checkpoint
            self._find_load_valid_checkpoint(recent_checkpoint_dir)
        else:
            self.current_checkpoint = None
            if not smoke_test:
                self.logger.info("Creating new model")
            self.create_model()
            self._update_save_dict()

    def init_state(
        self,
        run_config: RunConfig,
        smoke_test: bool = False,
        debug: bool = False,
        resume: bool = False,
        remote_progress_bar: ty.Optional[RemoteProgressBar] = None,
        from_chkpt: Path | str | None = None,
        data_lock: ty.Optional[Lock] = None,
    ):
        """
        Initializes the state of the wrapper based on provided configuration and parameters.
        The lazy-initialization of the wrapper is necessary, as the wrapper must be picklable.
        Initializing some objects (e.g. Dataloaders) leads to the opposite.

        Parameters
        ----------
        run_config : RunConfig
            An instance of ``RunConfig`` containing configuration details.
        smoke_test : bool
            Whether to run as a smoke test, by default ``False``.
        debug : bool
            If True, disables logging and model directory creation, by default ``False``.
        resume : bool
            If True, tries to resume training from a checkpoint, by default ``False``.
        remote_progress_bar : ty.Optional[RemoteProgressBar]
            A remote progress bar can be used to report metrics from the internal progress bar.
        from_chkpt: Path | str | None, optional
            Path to the checkpoint to initialize the state from.
        data_lock: ty.Optional[Lock], optional
            Use a Lock to avoid downloading data concurrently.

        Raises
        ------
        RuntimeError
            If the state is already initialized with a real logger and the `smoke_test`,
            `debug` and `resume` flags are all ``False``.
        """
        if (
            self._is_init
            and not (smoke_test or debug or resume)
            # Check if the wrapper was previously initialized in dummy mode (e.g. debug or smoke_test)
            and not isinstance(getattr(self, "logger", butils.Dummy()), butils.Dummy)
        ):
            raise RuntimeError(f"{self.__class__.__name__} is already initialized. ")
        if self._is_init:
            self._reset_cached_attributes()
        self._init_attrs()
        self._is_init = True
        # Work on a private, unfrozen copy so the caller's config is left untouched.
        self.run_config = copy.deepcopy(run_config)
        self.run_config._unfreeze()  # pylint: disable=protected-access
        self.random_seed = self.run_config.random_seed
        if self.random_seed is not None:
            butils.set_seed(self.random_seed)
        self._make_dataloaders(self.run_config, data_lock=data_lock)
        self.run_config = self._config_parser(self.run_config)
        # ``epochs`` is a derived attribute, so it bypasses the guarded ``__setattr__``.
        v = self.run_config.train_config.epochs
        self.__setattr__internal("epochs", v)
        if self.run_config.experiment_dir is not None:
            self.experiment_dir = (
                Path(self.run_config.experiment_dir).resolve().absolute()
            )
            self.run_config.experiment_dir = self.experiment_dir.as_posix()
        self.run_config.assert_unambigious()
        self.run_config.freeze()
        # Does not create log artifacts during smoke test
        if not smoke_test:
            self._init_logger(resume=resume, debug=debug)
        else:
            self.logger = butils.Dummy()
        self._init_class_attributes()
        if debug and self.experiment_dir is not None:
            self.logger.warn(
                f"Experiment Directory specified {self.experiment_dir} while running on debug mode. "
                "If saving artifacts is unnecessary you can disable the file system by setting "
                "`run_config.experiment_dir=None`. "
            )
        self._init_model_state(resume, smoke_test or debug, from_chkpt=from_chkpt)
        if self.verbose == "progress" and not smoke_test:
            self.progress_bar = ProgressBar(
                epoch_len=self.epoch_len,
                total_steps=self.total_steps,
                logfile=self.logger.log_file_path,
                remote_display=remote_progress_bar,
                uid=self.uid,
            )
        else:
            self.progress_bar = butils.Dummy()

    def _find_load_valid_checkpoint(self, chkpt_dir: Path):
        """
        Finds and loads the latest valid checkpoint from the given directory.

        Checkpoints are tried in the order returned by ``butils.get_latest_chkpts`` and the
        first one that loads successfully is kept, so a single corrupted checkpoint (e.g. from
        a crash) does not prevent resuming.

        Parameters
        ----------
        chkpt_dir : Path
            The directory containing the checkpoints.

        Raises
        ------
        CheckpointNotFoundError
            If ``chkpt_dir`` contains no checkpoints.
        RuntimeError
            If every checkpoint in ``chkpt_dir`` fails to load.
        """
        latest_checkpoints = butils.get_latest_chkpts(chkpt_dir)
        current_checkpoint = None
        if len(latest_checkpoints) > 0:
            # Try to load first valid chkpt in case there was a crash and some checkpoint is unrecoverable
            for i, _checkpoint in enumerate(latest_checkpoints):
                try:
                    self.logger.info(f"Loading checkpoint {_checkpoint}")
                    self._load_model(_checkpoint, model_only=False)
                    current_checkpoint = _checkpoint
                    break
                except Exception as e:  # pylint: disable=broad-exception-caught
                    if i == len(latest_checkpoints) - 1:
                        # if it is the last checkpoint raise exception
                        raise RuntimeError("Checkpoint not found") from e
                    # ignore exception
                    self.logger.error(
                        f"Error loading checkpoint {_checkpoint}. Trying another....\n{traceback.format_exc()}"
                    )
        if current_checkpoint is None:
            raise CheckpointNotFoundError(
                f"Could not find a valid checkpoint in {chkpt_dir}"
            )
        self.current_checkpoint = current_checkpoint

    def _load_model(self, checkpoint_path: Path, model_only: bool = False):
        """
        Loads the model and its state from the checkpoint file at the specified path.

        Parameters
        ----------
        checkpoint_path : Path
            The path to the checkpoint file containing the model and its state.
        model_only : bool
            If True, only the model's weights will be loaded, ignoring other state information,
            by default ``False``.

        Raises
        ------
        NotImplementedError
            If the model's run configuration is not initialized before attempting to load the
            model.
        RuntimeError
            If the file can not be loaded with ``torch.load``, or, when ``model_only=False``,
            if the stored configuration has a different ``uid`` than the current one.
        """
        if not hasattr(self, "run_config") or self.run_config is None:
            raise NotImplementedError(
                "Can not load model on an uninitialized model state. Consider run init_experiment_state function first"
            )
        try:
            # NOTE: ``torch.load`` unpickles the file; only load checkpoints from trusted sources.
            save_dict = torch.load(checkpoint_path, map_location="cpu")
        except Exception as e:
            raise RuntimeError(
                f"{checkpoint_path} is not a valid checkpoint e.g. a `.pt` file. "
            ) from e
        if model_only:
            self.load_checkpoint(save_dict, model_only=model_only)
            return
        run_config = type(self.run_config)(**save_dict["run_config"])
        if run_config.uid != self.run_config.uid:
            diffs = "\n\t".join(
                run_config.diff_str(self.run_config, ignore_stateless=True)
            )
            raise RuntimeError(
                f"Mismatching loaded and current configurations. \n{diffs}"
            )
        diffs = "\n\t".join(
            run_config.diff_str(self.run_config, ignore_stateless=False)
        )
        if len(diffs) > 0:
            self.logger.warn(
                f"Differences between initial configuration and current configuration. \n{diffs}"
            )
        self._load_stats(save_dict)
        self.load_checkpoint(save_dict, model_only=model_only)
        self.current_state = save_dict

    @abstractmethod
    def load_checkpoint(self, save_dict: dict[str, ty.Any], model_only: bool = False):
        """
        Abstract method to load the model and its state from a given save dictionary.

        Must be implemented by subclasses.
        Example implementation: Please see the ``load_checkpoint`` method in the
        ``ModelWrapper`` class.

        Parameters
        ----------
        save_dict : dict[str, ty.Any]
            A dictionary containing the saved model state and other necessary information.
        model_only : bool
            If ``True``, only the model's weights will be loaded, ignoring other state
            information, by default ``False``.

        Raises
        ------
        NotImplementedError
            If this method is not implemented by the subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def save_dict(self) -> dict[str, ty.Any] | None:
        """
        Abstract method to create and return a save dictionary containing the model's state
        and other necessary information.

        Must be implemented by subclasses.
        Example implementation: Please see the ``save_dict`` method in the ``ModelWrapper``
        class.

        Returns
        -------
        dict[str, ty.Any] | None
            A dictionary containing the saved model state and other necessary information.

        Raises
        ------
        NotImplementedError
            If this method is not implemented by the subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluation_functions(self) -> dict[str, Callable] | None:
        """
        Abstract method to create and return a dictionary of evaluation functions used during
        training and validation.

        Must be implemented by subclasses.
        Example implementation: Please see the ``evaluation_functions`` method in the
        ``ModelWrapper`` class.

        Returns
        -------
        dict[str, Callable] | None
            A dictionary containing evaluation functions as values and their names as keys.

        Raises
        ------
        NotImplementedError
            If this method is not implemented by the subclasses.
        """
        raise NotImplementedError

    def __setattr__internal(self, k, v):
        """Sets an attribute, bypassing the derived-attribute guard of ``__setattr__``."""
        super().__setattr__(k, v)

    def __setattr__(self, k, v):
        """
        Sets an attribute, refusing derived attributes (``_derived_stats_names``) unless they
        are listed in ``_overridable_stats_names``.

        Raises
        ------
        RuntimeError
            When ``k`` is a derived, non-overridable attribute.
        """
        # The guard lists may not exist yet while ``__init__`` is running.
        try:
            _derived_stats_names = super().__getattribute__("_derived_stats_names")
        except AttributeError:
            _derived_stats_names = []
        try:
            _overridable_stats_names = super().__getattribute__(
                "_overridable_stats_names"
            )
        except AttributeError:
            _overridable_stats_names = []
        if k in _derived_stats_names and k not in _overridable_stats_names:
            raise RuntimeError(f"Can not set derived attribute {k}.")
        self.__setattr__internal(k, v)

    def __getattribute__(self, name, _ignore_init=False):
        """
        Guards attribute reads before ``init_state`` has been called.

        Private (underscore-prefixed) names are always readable. Other names raise
        ``RuntimeError`` until ``_is_init`` is set, except the functions listed in
        ``_init_function_names``.
        """
        if name.startswith("_") or super().__getattribute__("_is_init"):
            return super().__getattribute__(name)
        try:
            _init_function_names = super().__getattribute__("_init_function_names")
        except AttributeError:
            _init_function_names = []
        if name not in _init_function_names and not _ignore_init:
            raise RuntimeError(
                f"Can not read property {name} of unitialized {self.__class__.__name__}. "
                "It must be initialized with `init_state` before using."
            )
        return super().__getattribute__(name)

    def _load_stats(self, save_dict: dict) -> None:
        """
        Loads the saved training and validation metrics from the save_dict and updates the
        model's internal metrics with the loaded values.

        Derived statistics are not restored; a warning is logged if their saved value differs
        from the current one. ``best_*`` entries are restored into ``best_metrics`` and the
        remaining train statistics are set as attributes. Whatever is left is treated as
        moving-average metrics.

        Parameters
        ----------
        save_dict : dict
            A dictionary containing the saved model state, metrics, and other necessary
            information.
        """
        metrics = copy.deepcopy(save_dict["train_metrics"])
        best_metrics = [f"best_{k}" for k in self.best_metrics]
        for k in self.train_stats:
            if k in self._running_stats_names:
                continue
            if k in self._derived_stats_names:
                # We skip assigning this metric as it will be derived by other metrics.
                if getattr(self, k, None) != metrics[k]:
                    self.logger.warn(
                        f"Current attribute {k} value derived to {getattr(self, k)} and "
                        f"is different than loaded value {metrics[k]}. Will use the current value."
                    )
                del metrics[k]

        for k in self.train_stats:
            if k in metrics and k in best_metrics:
                # ``removeprefix`` strips the literal prefix; ``lstrip("best_")`` would strip any
                # of the characters b/e/s/t/_ and turn e.g. "best_test_loss" into "loss".
                self.best_metrics[k.removeprefix("best_")] = metrics[k]
                del metrics[k]
            elif k in metrics:
                setattr(self, k, metrics[k])
                del metrics[k]
        self.train_metrics.update_static_metrics(self.train_stats)
        # pylint: disable=protected-access
        self.train_metrics._update_ma_metrics(metrics)
        if "eval_metrics" in save_dict and self.eval_metrics is not None:
            metrics = copy.deepcopy(save_dict["eval_metrics"])
            # pylint: disable=protected-access
            self.eval_metrics._update_ma_metrics(metrics)

    def _update_save_dict(self, user_save_dict: dict[str, ty.Any] | None = None):
        """
        Updates the current state dictionary with run_config and metrics. If a user_save_dict
        is provided, it is also merged into the current state dictionary (its keys take
        precedence).

        Parameters
        ----------
        user_save_dict : dict[str, ty.Any] | None
            A dictionary containing user-defined information to be saved, by default ``None``.
        """
        self.current_state = {
            "run_config": self.run_config.to_dict(),
            "train_metrics": self.train_metrics.to_dict(),
        }
        if self.eval_metrics is not None:
            self.current_state["eval_metrics"] = self.eval_metrics.to_dict()
        if user_save_dict is not None:
            # Positional form also accepts non-string keys, unlike ``update(**d)``.
            self.current_state.update(user_save_dict)

    def _checkpoint(self, is_best: bool = False):
        """
        Updates the current state dictionary with user-defined save_dict and calls the
        checkpoint method.

        Parameters
        ----------
        is_best : bool
            Indicates if the current checkpoint is the best model so far, by default ``False``.
        """
        user_save_dict = self.save_dict()
        self._update_save_dict(user_save_dict)
        self.checkpoint(is_best=is_best)