import inspect
import time
import traceback
import typing as ty
from abc import abstractmethod
from collections.abc import Callable, Iterable
from functools import cached_property
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.cuda.amp import GradScaler
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import ablator.utils.base as butils
from ablator.config.proto import ModelConfig, Optim, RunConfig, TrainConfig
from ablator.main.model.main import EvaluationError, ModelBase, TrainPlateauError
from ablator.modules.metrics.main import LossDivergedError, Metrics
from ablator.modules.optimizer import OptimizerConfig
from ablator.modules.scheduler import Scheduler, SchedulerConfig
# pylint: disable=too-many-public-methods
# pylint: disable=too-many-instance-attributes
[docs]class ModelWrapper(ModelBase):
"""
A wrapper around ``model_class`` that removes training boiler-plate code. Once ``make_dataloader_train`` is
overridden to provide a training dataset, you can pass ``ModelWrapper`` object to the trainers
(``ProtoTrainer`` or ``ParallelTrainer``) along with a running configuration (``RunConfig`` or
``ParallelConfig``) to launch the experiment.
The wrapper lets you customize the training process by allowing the overriding of various functions, this makes
it adaptable to different training paradigms. Several customizing use cases are shown below.
Attributes
----------
model_class : torch.nn.Module
The model class to wrap.
model : torch.nn.Module
The model created from the model class or checkpoint
optimizer : Optimizer
The optimizer created from the optimizer config or checkpoint
scaler : GradScaler
The scaler created from the scaler config or checkpoint
scheduler : Scheduler
The scheduler created from the scheduler config or checkpoint
Parameters
----------
model_class : type[nn.Module]
The model class to wrap.
"""
def __init__(
self,
model_class: type[nn.Module],
):
super().__init__(
model_class=model_class,
)
# Will be loaded or created from checkpoint
self.model: nn.Module
self.optimizer: Optimizer
self.scaler: GradScaler
self.scheduler: Scheduler | None
self._prev_update_time: float = time.time()
self._derived_stats_names += [
"train_config",
"model_config",
]
self._overridable_stats_names = ["epochs"]
self._init_function_names += ["train", "evaluate"]
self._cached_properties += [
"_scheduler_step_when",
"_scheduler_requires_metric",
]
self._is_partially_optimized: None | bool = None
self._is_self_optim: None | bool = None
@property
def train_config(self) -> TrainConfig:
return self.run_config.train_config
@property
def model_config(self) -> ModelConfig:
return self.run_config.model_config
[docs] def create_model(
self,
save_dict: dict[str, ty.Any] | None = None,
strict_load: bool = True,
):
"""
Creates the model, optimizer, scheduler, and scaler from a saved checkpoint dictionary or from configuration
objects. You can overwrite this function and ``save_dict()`` function to customize the saving and loading of
the model, optimizer, and scheduler to your needs. An example for this is shown in `Saving and loading
multi-module models <./notebooks/Multi-Modules.ipynb>`_ tutorial.
Parameters
----------
save_dict : dict[str, ty.Any] | None
The saved checkpoint dictionary to load from, by default ``None``.
strict_load : bool
Whether to throw an error for mismatched keys, by default ``True``.
"""
save_dict = {} if save_dict is None else save_dict
scheduler_state = save_dict["scheduler"] if "scheduler" in save_dict else None
optimizer_state = save_dict["optimizer"] if "optimizer" in save_dict else None
scaler_state = save_dict["scaler"] if "scaler" in save_dict else None
model_class = self.model_class
model: nn.Module
if (model_config := self.model_config) is not None:
model = model_class(model_config)
else:
# Support of declarative paradigm without model over-writing
model = model_class()
if "model" in save_dict:
model.load_state_dict(save_dict["model"], strict=strict_load)
elif inspect.ismethod(getattr(model, "init_weights", None)):
# TODO tutorial on this use-case
getattr(model, "init_weights")()
model = model.to(self.device)
optimizer = self.create_optimizer(
model=model,
optimizer_config=self.train_config.optimizer_config,
optimizer_state=optimizer_state,
)
scheduler = self.create_scheduler(
model=model,
optimizer=optimizer,
scheduler_config=self.train_config.scheduler_config,
scheduler_state=scheduler_state,
)
scaler = self.create_scaler(scaler_state=scaler_state)
self.model = model
self.optimizer = optimizer
self.scaler = scaler
self.scheduler = scheduler
[docs] def create_scheduler(
self,
model: nn.Module,
optimizer: Optimizer,
scheduler_config: SchedulerConfig | None = None,
scheduler_state: dict[str, ty.Any] | None = None,
) -> ty.Optional[Scheduler]:
"""
Creates the scheduler from the saved state or from config.
Parameters
----------
model : nn.Module
The model to create the scheduler for.
optimizer : Optimizer
The optimizer to create the scheduler for.
scheduler_config : SchedulerConfig | None
The scheduler config to create the scheduler from, by default ``None``.
scheduler_state : dict[str, ty.Any] | None
The scheduler state to load the scheduler from, by default ``None``.
Returns
-------
ty.Optional[Scheduler]
The scheduler.
"""
scheduler: ty.Optional[Scheduler] = None
if scheduler_config is not None:
scheduler = scheduler_config.make_scheduler(model, optimizer)
if scheduler_state is not None:
if scheduler is None:
self.logger.warn(
"Supplied `scheduler_state` without `scheduler_config`. Ignoring scheduler."
)
return None
scheduler.load_state_dict(scheduler_state)
return scheduler
[docs] def create_optimizer(
self,
model: nn.Module,
optimizer_config: OptimizerConfig | None = None,
optimizer_state: dict[str, ty.Any] | None = None,
) -> Optimizer:
"""
Creates the optimizer from the saved state or from config.
Parameters
----------
model : nn.Module
The model to create the optimizer for.
optimizer_config : OptimizerConfig | None
The optimizer config to create the optimizer from, by default ``None``.
optimizer_state : dict[str, ty.Any] | None
The optimizer state to load the optimizer from, by default ``None``.
Returns
-------
Optimizer
The optimizer.
"""
optimizer: Optimizer
if optimizer_config is not None:
optimizer = optimizer_config.make_optimizer(model)
if optimizer_state is not None and optimizer is not None:
# NOTE: because https://github.com/pytorch/pytorch/issues/80809
# TODO any good fix for this yet?
for k in optimizer_state["state"].keys():
if "step" in optimizer_state["state"][k] and isinstance(
optimizer_state["state"][k]["step"], torch.Tensor
):
optimizer_state["state"][k]["step"] = optimizer_state["state"][k][
"step"
].cpu()
optimizer.load_state_dict(optimizer_state)
elif optimizer_state is not None:
self.logger.warn(
"Supplied `optimizer_state` without `optimizer_config`. Ignoring optimizer."
)
return optimizer
[docs] def create_scaler(
self, scaler_state: ty.Optional[dict[str, ty.Any]] = None
) -> GradScaler:
"""
Creates the scaler from the saved state or from config.
Parameters
----------
scaler_state : ty.Optional[dict[str, ty.Any]]
The scaler state to load the scaler from, optional, by default ``None``.
Returns
-------
GradScaler
The scaler.
"""
scaler = GradScaler(enabled=self.amp)
if scaler_state:
scaler.load_state_dict(scaler_state)
return scaler
[docs] def load_checkpoint(self, save_dict: dict[str, ty.Any], model_only: bool = False):
"""
Loads the checkpoint from the save dict.
Parameters
----------
save_dict : dict[str, ty.Any]
The save dict to load the checkpoint from.
model_only : bool
Whether to load only the model or include scheduler, optimizer and scaler, by default ``False``.
Notes
-----
This method is the implementation of the abstract method in the base class.
"""
if model_only:
del save_dict["scheduler"]
del save_dict["optimizer"]
del save_dict["scaler"]
self.create_model(
save_dict,
strict_load=True,
)
[docs] def to_device(self, data: Iterable, device: ty.Optional[str] = None) -> Iterable:
"""
Moves the data to the specified device.
Parameters
----------
data : Iterable
The data to move to the device.
device : ty.Optional[str]
The device to move the data to. If ``None``, the device specified in the config is used.
Returns
-------
Iterable
The data on the device.
"""
if device is None:
device = self.device
return butils.iter_to_device(data, device)
[docs] def model_step(
self, model: nn.Module, batch: Iterable
) -> tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]:
"""
A single inference step for the model.
Parameters
----------
model : nn.Module
The model to train.
batch : Iterable
The batch of input data to pass through the model,it could be a list, dict or a single tensor.
Returns
-------
tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]
The output of the model,contains current predictions and loss of the model
"""
batch = self.to_device(batch)
if isinstance(batch, list):
out = model(*batch)
elif isinstance(batch, dict):
out = model(**batch)
else:
out = model(batch)
return out
@ty.final
def _update_learning_rate(self):
self.learning_rate = butils.get_lr(self.optimizer)
return self.learning_rate
@ty.final
def _inc_iter(self):
self.current_iteration += 1
def _is_step(self, step_interval: int) -> bool:
return (
step_interval > 0
and self.current_iteration > 0
and self.current_iteration % step_interval == 0
)
def _train_evaluation_step(self, smoke_test: bool = False):
# pylint: disable=too-complex
is_best = False
optim_metric = None
# If we are within 10% of the start or end of an epoch, we skip
# evalaution of train metrics for faster training
if (
self.current_iteration % self.epoch_len > 0.1 * self.epoch_len
and self.current_iteration % self.epoch_len < 0.9 * self.epoch_len
):
self.train_metrics.evaluate(reset=False)
if self.val_dataloader is not None and self.eval_metrics is not None:
self._validation_loop(
model=self.model,
dataloader=self.val_dataloader,
metrics=self.eval_metrics,
subsample=self.run_config.eval_subsample,
smoke_test=smoke_test,
)
else:
self.logger.warn(
"Validation dataloader and metrics were not set. Will be skipping `validation_loop`."
)
if (
self.optim_metric_name is not None
and self.optim_metric_name not in self.metrics
):
metric_names = sorted(list(self.metrics.keys()))
raise RuntimeError(
f"optim_metric_name=`{self.optim_metric_name}` not found in metrics {metric_names}. "
"Make sure your validation loader and validation loop are configured correctly."
)
if self.optim_metric_name is not None:
# Use val loss for scheduling or finding best checkpoint
optim_metric = self.metrics[self.optim_metric_name]
optim_direction = self.optim_metric_direction
best_metric = self.best_metrics[self.optim_metric_name]
is_warmup = (
self.current_iteration
<= self.epoch_len * self.run_config.warm_up_epochs
)
if is_best := (
(optim_direction == Optim.min and optim_metric < best_metric)
or (optim_direction == Optim.max and optim_metric > best_metric)
):
self.best_iteration = self.current_iteration
self.best_metrics[self.optim_metric_name] = optim_metric
best_metric = optim_metric
elif not is_warmup:
relative_change_a = (optim_metric - best_metric + 1e-5) / abs(
best_metric + 1e-5
)
relative_change_b = (optim_metric - best_metric + 1e-5) / abs(
optim_metric + 1e-5
)
ratio = abs(relative_change_a + relative_change_b) / 2
div_factor = self.run_config.divergence_factor
if div_factor is not None and (ratio > div_factor):
raise LossDivergedError(
f"Val {self.optim_metric_name} {optim_metric:.2e} has diverged by "
f"a factor larger than {self.run_config.divergence_factor:0.0f} to "
f"best_{self.optim_metric_name} {best_metric:.2e}"
)
elif self.scheduler is not None and self._scheduler_requires_metric:
raise EvaluationError(
f"A validation optimization argument is required with "
f"{self.scheduler.__class__.__name__} scheduler. "
"Try setting a `optim_metric_name`"
)
if self._scheduler_requires_metric:
self.scheduler_step(optim_metric, is_val_step=True)
else:
self.scheduler_step(is_val_step=True)
self.update_status()
self._checkpoint()
if is_best:
self._checkpoint(is_best=True)
# Early stopping
early_stopping_iter = self.run_config.early_stopping_iter
if (
early_stopping_iter is not None
and self.best_iteration is not None
and (self.current_iteration - self.best_iteration) > early_stopping_iter
):
diff = self.current_iteration - self.best_iteration
raise TrainPlateauError(
f"Early stopping. No improvement for {diff} > early_stopping_iter = `{early_stopping_iter}` iterations."
)
def _model_step(
self, model: nn.Module, batch: Iterable
) -> tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]:
with self._autocast:
out = self.model_step(model=model, batch=batch)
try:
outputs, loss = out
assert isinstance(outputs, (dict, type(None))) and isinstance(
loss, (torch.Tensor, type(None))
)
if outputs is not None:
for k, v in outputs.items():
assert isinstance(k, str) and isinstance(v, torch.Tensor)
except Exception as exc:
raise RuntimeError(
"Model should return outputs: dict[str, torch.Tensor] | None, loss: torch.Tensor | None."
) from exc
return outputs, loss
[docs] @ty.final
def train_step(
self, batch: Iterable
) -> tuple[dict[str, torch.Tensor] | None, dict[str, ty.Any]]:
"""
A single step for training.
It also updates learning rate with scheduler.
Parameters
----------
batch : Iterable
The batch of input data to pass through the model,it could be a list, dict or a single tensor.
Returns
-------
tuple[dict[str, torch.Tensor] | None, dict[str, ty.Any]]
- outputs : dict[str, torch.Tensor] | None
The output of the model.
- train_metrics: dict[str, ty.Any]
The training metrics.
"""
model = self.model
optimizer = self.optimizer
scaler = self.scaler
# Ensure no left-over grads are in the model's parameters from custom evaluation or what-not
optimizer.zero_grad(set_to_none=True)
outputs, loss = self._model_step(model=model, batch=batch)
loss_value = self.apply_loss(model, loss, optimizer, scaler)
self.scheduler_step()
self._inc_iter()
self._update_learning_rate()
train_metrics = {}
if loss is not None:
train_metrics["loss"] = loss_value
return outputs, train_metrics
[docs] def log_step(self):
"""
A single step for logging.
Notes
-----
This method is update the logger with the current metrics and log a status message.
"""
self.logger.update(self.metrics)
self.update_status()
msg = self.status_message()
verbose = self.verbose == "console"
self.logger.info(msg, verbose=verbose)
[docs] def update_status(self) -> None:
"""
Update the metrics with current training stats,
and then all metrics (static and moving average) will be set as description for the ``tqdm`` progress.
"""
self.train_metrics.update_static_metrics(self.train_stats)
if self.verbose != "progress":
return
self.progress_bar.update_metrics(
self.metrics, self.current_iteration % self.epoch_len
)
@property
def metrics(self) -> dict[str, float]:
"""
The metrics of the current training state.
If ``eval_metrics`` are defined (e.g. a validation dataloader was provided)
then they are combined and a `val_` and `train_` label is prepended.
The current training statistics are also combined
Returns
-------
dict[str, float]
A dictionary with keys the metric name and the value corresponding to the metric.
"""
metrics = {}
if self.eval_metrics is not None:
for k, v in self.eval_metrics.to_dict().items():
metrics[f"val_{k}"] = v
for k, v in self.train_metrics.to_dict().items():
_k = f"train_{k}" if f"val_{k}" in metrics else k
metrics[_k] = v
return metrics
[docs] def status_message(self) -> str:
"""
Return a string generated from dictionary of current metrics,
including all the static metrics and moving average metrics.
Returns
-------
str
The status message.
"""
# must return current epoch, iter, losses and metrics
msg_str = ""
for k, v in self.metrics.items():
msg_str += f"{k}: {butils.num_format(v)} "
return msg_str.strip()
[docs] def log(self):
"""
Log if the current iteration is a logging step. It also evaluate training metrics for logging.
"""
# Log step
update_interval = 1
if self._is_step(self.log_itr):
self.train_metrics.evaluate(reset=False)
self.log_step()
elif (
self.verbose == "progress"
and time.time() - self._prev_update_time > update_interval
):
self.update_status()
self._prev_update_time = 1
[docs] @ty.final
def eval(self, smoke_test: bool = False):
"""
Evaluate the model then update scheduler and save checkpoint if the current iteration is an evaluation step.
It also check if it is early stopping (check Model Configuration module for more details).
Parameters
----------
smoke_test : bool
If True, for a smoke test, by default ``False``.
Raises
------
LossDivergedError
If the loss is diverged during eval.
TrainPlateauError
If loss begins to slow down dramatically.
"""
# Evaluation step
if self._is_step(self.eval_itr):
try:
self._train_evaluation_step(smoke_test=smoke_test)
except (LossDivergedError, TrainPlateauError) as e:
error = traceback.format_exc()
self.logger.error(error)
raise e
finally:
eval_step = (
self.current_iteration
if self.eval_itr == 0
else self.current_iteration // self.eval_itr
)
msg = self.status_message()
self.logger.info(f"Evaluation Step [{eval_step}] {msg}", verbose=False)
@property
def total_steps(self) -> int: # type: ignore[override]
"""
The total number of steps for training.
Returns
-------
int
total steps = epoch's length * number of epochs.
"""
return self.epoch_len * self.epochs
[docs] def train_loop(self, smoke_test: bool = False) -> dict[str, float]:
"""
Train the model in many steps, evaluate the model and log the metrics for each iteration.
metrics including static metrics like learning rate, along with validation and
training metrics like loss and mean.
Parameters
----------
smoke_test : bool
Whether to run a smoke test.
Returns
-------
dict[str, float]
metrics after completion of training.
Raises
------
ValueError
If model outputs are not stored in correct format.
LossDivergedError
If the loss diverged.
"""
train_dataloader = self.train_dataloader
generator = iter(train_dataloader)
for i in range(self.current_iteration, self.total_steps):
self.model.train()
try:
batch = next(generator)
except StopIteration:
# restart the generator if the previous generator is exhausted.
generator = iter(train_dataloader)
batch = next(generator)
self.train_metrics.evaluate()
self.progress_bar.reset()
outputs, moving_metrics = self.train_step(batch)
if outputs is not None:
try:
self.train_metrics.append_batch(**outputs)
except ValueError as e:
raise ValueError(
"Can not store model outputs to metrics. "
"Make sure that the model outputs are formatted correctly. "
) from e
self.train_metrics.update_ma_metrics(moving_metrics)
if "loss" in moving_metrics and not np.isfinite(moving_metrics["loss"]):
msg = f"Loss Diverged. Terminating. loss: {moving_metrics['loss']}"
self.logger.error(msg)
raise LossDivergedError(msg)
if not smoke_test:
self.eval()
self.log()
if smoke_test and i > self.epoch_len * 0.01:
self.eval(smoke_test=True)
break
return self.metrics
[docs] @ty.final
def train(
self,
run_config: RunConfig | None = None,
smoke_test: bool = False,
debug: bool = False,
resume: bool = False,
) -> dict[str, float]:
"""
Initialize states and train the model.
When keyboard interrupts, saves a checkpoint
Parameters
----------
run_config : RunConfig | None
The run config to use for training, by default ``None``.
smoke_test : bool
Whether to run a smoke test, by default ``False``.
debug : bool
Whether to run in debug mode, by default ``False``.
resume : bool
Whether to resume training the model from existing checkpoints and
existing experiment state, by default ``False``.
Returns
-------
dict[str, float]
The metrics from the training.
Raises
------
ValueError
if the state is not initialized and no ``run_config`` is provided or when a ``run_config`` is
provided but the state is already initialized.
"""
if not self._is_init and run_config is not None:
self.init_state(
run_config=run_config,
smoke_test=smoke_test,
debug=debug,
resume=resume,
)
elif not self._is_init:
raise ValueError(
f"{self.__class__.__name__} is not initialized. Must provide a `run_config`."
)
elif debug or smoke_test:
self.init_state(
run_config=self.run_config,
smoke_test=smoke_test,
debug=debug,
resume=resume,
)
elif run_config is not None:
raise ValueError(
f"Can not provide `run_config` to already initialized `{self.__class__.__name__}`"
)
if self.current_iteration == self.total_steps:
self.logger.warn(
f"Training is already complete: {self.current_iteration} / {self.total_steps}. "
"Returning current metrics."
)
return self.metrics
try:
return self.train_loop(smoke_test)
except KeyboardInterrupt:
self._checkpoint()
finally:
self.progress_bar.close()
msgs = (
[]
if isinstance(self.progress_bar, butils.Dummy)
else self.progress_bar.make_metrics_message(self.metrics)
)
for msg in msgs:
self.logger.info(msg, verbose=True)
return self.metrics
[docs] @ty.final
def evaluate(
self, run_config: RunConfig, chkpt: str | Path | None = None
) -> dict[str, dict[str, ty.Any]]:
"""
Evaluate the model after training on the test and validation sets.
Parameters
----------
run_config : RunConfig
The run config to use for evaluation.
chkpt: str | Path | None
Path to the checkpoint to evaluate. If None, the latest checkpoint is evaluated, by default ``None``.
Returns
-------
dict[str, dict[str, ty.Any]]
Metrics
"""
self.init_state(run_config, resume=True, from_chkpt=chkpt)
self.logger.info(f"Evaluating {self.current_checkpoint}")
self.update_status()
msg = self.metrics
self.logger.info(f"Current metrics: {msg}")
metrics = {}
for loader, tag in zip(
[self.test_dataloader, self.val_dataloader], ["test", "val"]
):
if loader is not None:
# NOTE we set max memory limit and let it crash because we do not want
# inaccurate metrics calculation. Possibly smarter ways to go about it.
eval_metrics = Metrics(
batch_limit=None,
memory_limit=None,
moving_average_limit=None,
evaluation_functions=self.evaluation_functions(),
moving_aux_metrics=["loss"],
)
self._validation_loop(
model=self.model,
dataloader=loader,
metrics=eval_metrics,
subsample=1,
)
metrics[tag] = eval_metrics.to_dict()
msg = eval_metrics.to_dict()
self.logger.info(f"Evaluating {tag}: {msg}")
self.update_status()
return metrics
[docs] @ty.final
def backward(
self, loss: torch.Tensor, scaler: torch.cuda.amp.GradScaler | None = None
):
if scaler is not None:
scaler.scale(loss).backward()
else:
loss.backward()
return loss.item()
[docs] @ty.final
def optim_step(
self,
optimizer: Optimizer,
scaler: torch.cuda.amp.GradScaler,
model: nn.Module,
loss: torch.Tensor | None,
):
if self.amp and loss is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
if self._is_partially_optimized is None and any(
p.grad is None
for param_group in optimizer.param_groups
for p in param_group["params"]
):
self.logger.warn(
"Not all optimization parameters contain gradients. "
"If this is expected behavior please ignore this message. "
"Otherwise make sure you are using your model correctly."
)
self._is_partially_optimized = True
elif self._is_partially_optimized is None:
self._is_partially_optimized = False
optimizer.zero_grad(set_to_none=True)
@cached_property
def _scheduler_step_when(self):
step_when = None
try:
step_when = self.train_config.scheduler_config.arguments.step_when
except AttributeError:
step_when = getattr(self.scheduler, "step_when", None)
return step_when
@cached_property
def _scheduler_requires_metric(self):
if self.scheduler is None:
return False
try:
params = inspect.signature(self.scheduler.step).parameters.keys()
return "metrics" in params
except AttributeError:
return False
[docs] @ty.final
def scheduler_step(self, *args, is_val_step=False):
"""
A single scheduler step. This function accounts for when the scheduler is supposed
to take a step. It reads `step_when` as a property from the scheduler or derives
it from the train configuration. `step_when` is expected to be in
{"train", "epoch", "val"}. When the scheduler is a validation scheduler,
it expects some mertrics passed as arguments to the scheduler step function.
e.g. `wrapper.scheduler_step(0.001)`.
When no args is provided and it is a validation scheduler, this function is
no-op.
"""
if (scheduler := self.scheduler) is None:
return
step_when = self._scheduler_step_when
if step_when == "train" or step_when is None:
scheduler.step(*args)
elif step_when == "epoch" and self._is_step(self.epoch_len):
scheduler.step(*args)
elif step_when == "val" and is_val_step:
scheduler.step(*args)
[docs] def apply_loss(
self,
model: nn.Module,
loss: torch.Tensor | None,
optimizer: Optimizer,
scaler: torch.cuda.amp.GradScaler,
) -> float | None:
"""
Calculate the loss and apply the gradients, call ``optimizer.step()`` and ``scheduler.step()``.
Parameters
----------
model : nn.Module
The model to apply the loss to.
loss : torch.Tensor | None
The loss to apply.
optimizer : Optimizer
The optimizer to step.
scaler: torch.cuda.amp.GradScaler
The scaler to use during mixed precision training.
Returns
-------
float | None
The loss value.
"""
if loss is not None:
loss = torch.mean(loss)
loss_value = self.backward(loss, scaler)
else:
if self._is_self_optim is None and all(
p.grad is None
for param_group in optimizer.param_groups
for p in param_group["params"]
):
self.logger.error(
"The loss returned by the model is `None` "
"and no optimization parameter contains gradients. "
"You need to perform optimization internally, either call `loss.backward()`"
" in the `model.forward`, or define your own optimizer to perform `optimizer.step()` "
"inside `model.forward`. "
)
self._is_self_optim = True
elif self._is_self_optim is None:
self._is_self_optim = False
loss_value = None
self.optim_step(optimizer, scaler, model, loss)
return loss_value
@torch.no_grad()
def _validation_loop(
self,
model: nn.Module,
dataloader: DataLoader,
metrics: Metrics,
subsample: float = 1.0,
smoke_test: bool = False,
) -> dict[str, float]:
was_training = model.training
model.eval()
_sampler = getattr(dataloader, "sampler", None)
is_random_sampling = isinstance(
_sampler, torch.utils.data.sampler.RandomSampler
)
if subsample < 1.0 and not is_random_sampling:
self.logger.warn(
f"Validating on a subsample=`{subsample}` without a random sampler,"
f"sampler=`{type(_sampler).__name__}`. The results can be biased. "
)
metrics_dict = self.validation_loop(
model, dataloader, metrics, subsample, smoke_test
)
if was_training:
model.train()
return metrics_dict
[docs] def validation_loop(
self,
model: nn.Module,
dataloader: DataLoader,
metrics: Metrics,
subsample: float = 1.0,
smoke_test: bool = False,
) -> dict[str, float]:
"""
Validate the model on data on dataloader and store results on metrics.
Parameters
----------
model : nn.Module
The model to validate.
dataloader : DataLoader
The dataloader to use for validation.
metrics : Metrics
The metrics to use for validation.
subsample : float
The fraction of the dataloader to use for validation, by default ``1.0``.
smoke_test : bool
Whether to execute this function as a smoke test. If ``True``, only one iteration will be performed,
which is useful for quickly checking if the code runs without errors, by default ``False``.
Returns
-------
dict[str, float]
The metrics from the validation.
"""
cutoff_itr = len(dataloader) * subsample
if model.training:
self.logger.warn(
"Called `validation_loop` without setting the model to evaluation mode. i.e. `model.eval()`"
)
for i, batch in enumerate(dataloader):
with torch.no_grad():
outputs, loss = self._model_step(model=model, batch=batch)
val_metrics = {}
if outputs is not None:
metrics.append_batch(**outputs)
if loss is not None:
val_metrics["loss"] = torch.mean(loss).item()
metrics.update_ma_metrics(val_metrics)
if i > cutoff_itr or smoke_test:
break
metrics.evaluate()
return metrics.to_dict()
[docs] @abstractmethod
def make_dataloader_train(self, run_config: RunConfig) -> DataLoader:
"""
Function to make the training dataloader. You must overwrite this function and
return the training dataloader.
Parameters
----------
run_config: RunConfig
The run configuration.
Returns
-------
DataLoader
The training dataloader.
Examples
--------
>>> class MyModelWrapper(ModelWrapper):
>>> def __init__(self, *args, **kwargs):
>>> super().__init__(*args, **kwargs)
>>> def make_dataloader_train(self, run_config: CustomRunConfig):
>>> return torch.utils.data.DataLoader(
... train_dataset,
... batch_size=32,
... shuffle=True
... )
>>> def make_dataloader_val(self, run_config: CustomRunConfig):
>>> return torch.utils.data.DataLoader(
... test_dataset,
... batch_size=32,
... shuffle=False
... )
"""
[docs] def evaluation_functions(self) -> dict[str, Callable] | None:
"""
You can overwrite this function and return the evaluation functions' callables
that will be used to evaluate experiment metrics.
Returns
-------
dict[str, Callable] | None
The evaluation functions to use.
Examples
--------
- Define the callables:
>>> def my_accuracy(y_true, y_pred):
>>> return accuracy_score(y_true.flatten(), y_pred.flatten())
>>> def my_f1_score(y_true, y_pred):
>>> return f1_score(y_true.flatten(), y_pred.flatten(), average='weighted')
- Returns callables in ``evaluation_functions``:
>>> class MyModelWrapper(ModelWrapper):
>>> def __init__(self, *args, **kwargs):
>>> super().__init__(*args, **kwargs)
>>> def make_dataloader_train(self, run_config: CustomRunConfig):
>>> return torch.utils.data.DataLoader(
... train_dataset,
... batch_size=32,
... shuffle=True
... )
>>> def make_dataloader_val(self, run_config: CustomRunConfig):
>>> return torch.utils.data.DataLoader(
... val_dataset,
... batch_size=32,
... shuffle=False
... )
>>> def evaluation_functions(self):
>>> return {
... "accuracy": my_accuracy,
... "f1": my_f1_score
... }
- Note that the callable's parameter names must match the model's forward output. In our example,
``y_true`` and ``y_pred`` must be returned by the model's forward method to match with ``y_true``
and ``y_pred`` parameters of ``my_accuracy`` and ``my_f1_score`` functions:
>>> class MyModel(nn.Module):
>>> def __init__(self, config: CustomModelConfig) -> None:
>>> super().__init__()
>>> self.model = FashionMNISTModel(config)
>>> self.loss = nn.CrossEntropyLoss()
>>> def forward(self, x, labels=None):
>>> out = self.model(x)
>>> loss = None
>>> if labels is not None:
>>> loss = self.loss(out, labels)
>>> out = out.argmax(dim=-1)
>>> return {"y_pred": out, "y_true": labels}, loss
"""
# Functions that can be optionally over-written.
[docs] def make_dataloader_test(self, run_config: RunConfig) -> DataLoader | None:
"""
Function to make the test dataloader. You can overwrite this function and
return the test dataloader.
Parameters
----------
run_config: RunConfig
The run configuration.
Returns
-------
DataLoader | None
The test dataloader.
Examples
--------
>>> class MyModelWrapper(ModelWrapper):
>>> def __init__(self, *args, **kwargs):
>>> super().__init__(*args, **kwargs)
>>> def make_dataloader_train(self, run_config: CustomRunConfig):
>>> return torch.utils.data.DataLoader(
... train_dataset,
... batch_size=32,
... shuffle=True
... )
>>> def make_dataloader_test(self, run_config: CustomRunConfig):
>>> return torch.utils.data.DataLoader(
... test_dataset,
... batch_size=32,
... shuffle=False
... )
"""
[docs] def make_dataloader_val(self, run_config: RunConfig) -> DataLoader | None:
"""
Function to make the validation dataloader. You can overwrite this function and
return the validation dataloader.
Parameters
----------
run_config : RunConfig
The run configuration.
Returns
-------
DataLoader | None
The validation dataloader.
Examples
--------
>>> class MyModelWrapper(ModelWrapper):
>>> def __init__(self, *args, **kwargs):
>>> super().__init__(*args, **kwargs)
>>> def make_dataloader_train(self, run_config: CustomRunConfig):
>>> return torch.utils.data.DataLoader(
... train_dataset,
... batch_size=32,
... shuffle=True
... )
>>> def make_dataloader_val(self, run_config: CustomRunConfig):
>>> return torch.utils.data.DataLoader(
... val_dataset,
... batch_size=32,
... shuffle=False
... )
"""
[docs] def config_parser(self, run_config: RunConfig) -> RunConfig:
"""
You can overwrite this function to initialize ``Derived`` properties that are not decided
until the experiment is launched.
Parameters
----------
run_config : RunConfig
The run config for the experiment.
Returns
-------
RunConfig
Examples
--------
For example, in GPT2 models, sometimes we need to resize its vocabulary size to match the tokenizer's
vocabulary size depending on the tokenizer used. This is decided only after the experiment has
been launched. The reason for this is that you might want to run an ablation study on different tokenizers:
>>> class MyLMWrapper(ModelWrapper):
... def config_parser(self, run_config: RunConfig):
... run_config.model_config.resize_token_embedding = len(self.train_dataloader.dataset.tokenizer)
... return run_config
"""
return run_config
[docs] def make_dataloaders(self, run_config: RunConfig) -> None:
"""
This function is done post-initialization because otherwise the
dataloaders are pickled with the object when running distributed.
Parameters
----------
run_config : RunConfig
The run config for the experiment.
"""
self.train_dataloader = self.make_dataloader_train(run_config)
# pylint: disable=assignment-from-no-return
self.val_dataloader = self.make_dataloader_val(run_config)
# pylint: disable=assignment-from-no-return
self.test_dataloader = self.make_dataloader_test(run_config)
[docs] def checkpoint(self, is_best: bool = False):
"""
Save a checkpoint of the model.It will use the class name of the model as the filename.
Parameters
----------
is_best : bool
Whether this is the best model so far, by default ``False``.
"""
self.logger.checkpoint(
self.current_state,
self.model.__class__.__name__,
is_best=is_best,
itr=self.current_iteration,
)
[docs] def save_dict(self) -> dict[str, ty.Any]:
"""
Save the current state of the trainer, including model parameters, and the current states of the optimizer,
scaler, and scheduler. You can overwrite this function and ``create_model()`` to customize the saving and
loading of the model, optimizer, and scheduler to your needs. An example of this is shown in
`Saving and loading multi-module models <./notebooks/Multi-Modules.ipynb>`_ tutorial.
Returns
-------
dict[str, ty.Any]
The current state of the trainer.
- model: the model's state
- optimizer: the state of optimizer
- scheduler: the scheduler's state
- scaler: the state of gradScaler.
"""
model_state_dict = self.model.state_dict()
optimizer_state_dict = None
if getattr(self, "optimizer", None) is not None:
optimizer_state_dict = self.optimizer.state_dict()
scheduler_state_dict = None
if getattr(self, "scheduler", None) is not None:
scheduler_state_dict = self.scheduler.state_dict() # type: ignore[union-attr]
scaler_state_dict = None
if getattr(self, "scaler", None) is not None:
scaler_state_dict = self.scaler.state_dict()
return {
"model": model_state_dict,
"optimizer": optimizer_state_dict,
"scheduler": scheduler_state_dict,
"scaler": scaler_state_dict,
}