ablator.main.model package#

Submodules#

ablator.main.model.main module#

exception ablator.main.model.main.CheckpointNotFoundError[source]#

Bases: FileNotFoundError

exception ablator.main.model.main.EvaluationError[source]#

Bases: Exception

exception ablator.main.model.main.LogStepError[source]#

Bases: Exception

class ablator.main.model.main.ModelBase(model_class: type[torch.nn.modules.module.Module])[source]#

Bases: ABC

Base class that removes training boilerplate code, with extensible support for multiple use cases. The class follows a stateful initialization paradigm. Users must implement the model loading and creation functionality specific to their use case.

Parameters:
model_classtype[nn.Module]

The base class for user’s model, which defines the neural network.

Notes

  1. Class properties are simply listed by name. See each property's docstring for more information.

  2. Users must implement the abstract methods to customize the model’s behavior.

  3. Mixed precision training runs some operations in the torch.float32 datatype and others in a lower-precision floating-point datatype such as torch.float16, which saves time and reduces memory usage. Ordinarily, “automatic mixed precision training” means training with torch.autocast and torch.cuda.amp.GradScaler together. More information: https://pytorch.org/docs/stable/amp.html
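Note 2 can be illustrated with a minimal sketch of the abstract-base-class pattern. TrainerBase, IncompleteTrainer and MyTrainer are hypothetical names used only for illustration and are not part of ablator; the sketch shows only the standard abc behaviour, namely that a subclass cannot be instantiated until every abstract method is overridden.

```python
from abc import ABC, abstractmethod


class TrainerBase(ABC):
    """Hypothetical stand-in for a base class with abstract hooks."""

    @abstractmethod
    def make_dataloaders(self, run_config):
        raise NotImplementedError

    @abstractmethod
    def train(self, run_config):
        raise NotImplementedError


class IncompleteTrainer(TrainerBase):
    # Only one of the two abstract methods is overridden.
    def make_dataloaders(self, run_config):
        return []


class MyTrainer(IncompleteTrainer):
    # Overriding the remaining abstract method makes the class concrete.
    def train(self, run_config):
        return {"loss": 0.0}


def try_instantiate(cls):
    """Return an instance, or the TypeError message raised by abc."""
    try:
        return cls()
    except TypeError as err:
        return str(err)


incomplete_result = try_instantiate(IncompleteTrainer)  # an error message
complete_result = try_instantiate(MyTrainer)            # a MyTrainer instance
```

With abc, a missing override fails at instantiation time with a TypeError, before any abstract method body is reached.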

Attributes:
model_classType[nn.Module]

The class definition of the model’s structure, which is a subclass of nn.Module.

run_configRunConfig

An instance of RunConfig containing configuration details.

train_dataloaderDataLoader

A DataLoader object responsible for model training.

val_dataloaderOptional[DataLoader]

An optional DataLoader object used for model evaluation.

test_dataloaderOptional[DataLoader]

An optional DataLoader object used for model testing.

loggerUnion[SummaryLogger, Dummy]

Records information on the program’s operation and model training, such as progress and performance metrics.

devicestr

The type of device used for running the experiment, e.g. "cuda", "cpu", "cuda:0".

model_dirPath

The model directory.

experiment_dirPath

The experiment directory.

verbosebool

If True, prints additional information while training. Only applied for the master process.

ampbool

If True, apply automatic mixed precision training, otherwise default precision.

random_seedOptional[int]

Sets the seed for generating random numbers.

progress_barUnion[ProgressBar, Dummy]

An optional instance of ProgressBar that displays real-time information during training, e.g. time remaining. Only applied for the master process.

current_checkpointOptional[Path]

Directory for the current checkpoint file, by default None.

train_metricsMetrics

Training metrics including model information, e.g. learning rate and loss value.

eval_metricsMetrics | None

Evaluation metrics for when a val_dataloader is provided.

current_statedict

The current state of the model, including run_config, metrics and other necessary states.

learning_ratefloat

The current learning rate.

total_stepsint

The total steps for the training process.

epochsint

The total epochs for the training process.

current_iterationint

The current iteration of training.

best_iterationint

The iteration with the best loss value.

best_metricsdict[str, float]

The lowest optim values encountered during training.

optim_metric_namestr | None

The name of the optimization metric.

optim_metric_directionOptim | None

The optimization direction.
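The relationship between best_iteration, best_metrics and the optimization direction can be sketched as follows. This is a minimal illustration under stated assumptions, not ablator's implementation: Direction is a hypothetical stand-in for the Optim type, and the assumption is that a direction simply selects whether lower or higher metric values count as better.

```python
from enum import Enum


class Direction(Enum):
    """Hypothetical stand-in for an optimization-direction enum."""

    MIN = "min"
    MAX = "max"


def is_improvement(new, best, direction):
    """Return True if `new` beats `best` under the given direction."""
    if best is None:
        return True
    return new < best if direction is Direction.MIN else new > best


def track_best(values, direction):
    """Return (best_iteration, best_value) over a sequence of metric values."""
    best_value, best_iteration = None, None
    for iteration, value in enumerate(values):
        if is_improvement(value, best_value, direction):
            best_value, best_iteration = value, iteration
    return best_iteration, best_value


best_loss = track_best([0.9, 0.7, 0.8, 0.6, 0.65], Direction.MIN)
best_acc = track_best([0.5, 0.8, 0.7], Direction.MAX)
```

Here best_loss is (3, 0.6) and best_acc is (1, 0.8).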

abstract checkpoint(is_best: bool = False)[source]#

Abstract method to save a checkpoint of the model. Must be implemented by subclasses. Example implementation: Please see the checkpoint method in the ModelWrapper class.

Parameters:
is_bestbool

Indicates if the current checkpoint is the best model so far, by default False.

Raises:
NotImplementedError

If this method is not implemented by the subclasses.

abstract config_parser(run_config: RunConfig)[source]#

Abstract method to parse the provided configuration.

Must be implemented by subclasses. Example implementation: Please see the config_parser method in the ModelWrapper class.

Parameters:
run_configRunConfig

An instance of RunConfig containing configuration details.

Raises:
NotImplementedError

If this method is not implemented by the subclasses.

abstract create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) None[source]#

Abstract method to create and initialize the model. Must be implemented by subclasses. Example implementation: Please see the create_model method in the ModelWrapper class.

Parameters:
save_dictdict[str, ty.Any] | None

A dictionary containing saved model data, such as weights, optimizer state, etc., to be loaded into the model, by default None.

strict_loadbool

If True, the model will be loaded strictly, ensuring that the saved state matches the model’s structure exactly. If False, the model can be loaded with a partially matching state, by default True.

Raises:
NotImplementedError

If this method is not implemented by the subclasses.

property current_epoch: int#

Calculates and returns the current epoch during training.

Returns:
int

The current epoch number.

property epoch_len: int#

Returns the length of an epoch, which is the number of batches in the train_dataloader.

Returns:
int

The length of an epoch, represented as the number of batches in the train_dataloader.

Raises:
RuntimeError

If the train_dataloader is not defined or its length is 0.

property eval_itr: int#

Calculate the interval between evaluations.

Returns:
int

The interval between evaluations.

abstract evaluate(run_config: RunConfig)[source]#

Abstract method to evaluate the model. Must be implemented by subclasses. Example implementation: Please see the evaluate method in the ModelWrapper class.

Parameters:
run_configRunConfig

An instance of RunConfig containing configuration details.

Raises:
NotImplementedError

If this method is not implemented by the subclasses.

abstract evaluation_functions() dict[str, collections.abc.Callable] | None[source]#

Abstract method to create and return a dictionary of evaluation functions used during training and validation.

Must be implemented by subclasses. Example implementation: Please see the evaluation_functions method in the ModelWrapper class.

Returns:
dict[str, Callable] | None

A dictionary containing evaluation functions as values and their names as keys.

Raises:
NotImplementedError

If this method is not implemented by the subclasses.
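One way such a dictionary of evaluation functions could be consumed is sketched below. The ModelWrapper documentation states that each callable's parameter names must match the keys of the model's forward output; the dispatch helper call_with_matching_outputs is a hypothetical illustration of that rule using inspect.signature, and is not ablator's actual mechanism.

```python
import inspect


def call_with_matching_outputs(fn, outputs):
    """Call `fn` with the entries of `outputs` named by its parameters."""
    params = inspect.signature(fn).parameters
    missing = [name for name in params if name not in outputs]
    if missing:
        raise KeyError(f"model outputs lack keys required by {fn.__name__}: {missing}")
    return fn(**{name: outputs[name] for name in params})


def my_accuracy(y_true, y_pred):
    # Fraction of positions where prediction and label agree.
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


evaluation_functions = {"accuracy": my_accuracy}

# Plain lists stand in for the tensors a forward method would return.
outputs = {"y_pred": [1, 0, 1, 1], "y_true": [1, 0, 0, 1]}

results = {
    name: call_with_matching_outputs(fn, outputs)
    for name, fn in evaluation_functions.items()
}
```

In this sketch results is {"accuracy": 0.75}; renaming a parameter of my_accuracy to a key that the outputs dictionary lacks raises a KeyError.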

init_state(run_config: RunConfig, smoke_test: bool = False, debug: bool = False, resume: bool = False, remote_progress_bar: RemoteProgressBar | None = None, from_chkpt: str | Path | None = None, data_lock: Lock | None = None)[source]#

Initializes the state of the wrapper based on the provided configuration and parameters. Lazy initialization of the wrapper is necessary because the wrapper must be picklable, and initializing some objects (e.g. dataloaders) up front can make it unpicklable.
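The motivation for lazy initialization can be sketched with the standard pickle module. This is an illustration under stated assumptions rather than ablator code: plain dictionaries stand in for the wrapper's state, and a generator stands in for an object (such as a live dataloader iterator) that cannot be pickled.

```python
import pickle


def is_picklable(obj) -> bool:
    """Return True if `obj` survives pickle.dumps."""
    try:
        pickle.dumps(obj)
    except (TypeError, AttributeError, pickle.PicklingError):
        return False
    return True


def batches():
    """Hypothetical data source; generators cannot be pickled."""
    yield from ([1, 2], [3, 4])


# Eager: the unpicklable object is created at construction time.
eager_state = {"model_class": "MyModel", "train_dataloader": batches()}
# Lazy: the slot stays empty until the state is initialized.
lazy_state = {"model_class": "MyModel", "train_dataloader": None}

eager_ok = is_picklable(eager_state)
lazy_ok = is_picklable(lazy_state)

# The lazy state can be shipped to another process and completed there.
restored = pickle.loads(pickle.dumps(lazy_state))
restored["train_dataloader"] = batches()
first_batch = next(restored["train_dataloader"])
```

Only the lazy state can be serialized; the expensive or unpicklable members are created after the object has been transferred.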

Parameters:
run_configRunConfig

An instance of RunConfig containing configuration details.

smoke_testbool

Whether to run as a smoke test, by default False.

debugbool

If True, disables logging and model directory creation, by default False.

resumebool

If True, tries to resume training from a checkpoint, by default False.

remote_progress_barty.Optional[RemoteProgressBar]

A remote progress bar that can be used to report metrics from the internal progress bar.

from_chkpt: Path | str | None, optional

Path to the checkpoint to initialize the state from.

data_lock: ty.Optional[Lock], optional

Use a Lock to avoid downloading data concurrently.

Raises:
RuntimeError

If the state is already initialized and the smoke_test, debug and resume flags are all False.

abstract load_checkpoint(save_dict: dict[str, Any], model_only: bool = False)[source]#

Abstract method to load the model and its state from a given save dictionary.

Must be implemented by subclasses. Example implementation: Please see the load_checkpoint method in the ModelWrapper class.

Parameters:
save_dictdict[str, ty.Any]

A dictionary containing the saved model state and other necessary information.

model_onlybool

If True, only the model’s weights will be loaded, ignoring other state information, by default False.

Raises:
NotImplementedError

If this method is not implemented by the subclasses.

property log_itr: int#

Calculate the interval between logging steps.

Returns:
int

The interval between logging steps.

abstract make_dataloaders(run_config: RunConfig)[source]#

Abstract method to create dataloaders for the training, validation, and testing datasets.

This method should define the process of loading the data and creating dataloaders for the training, validation, and testing datasets based on the provided run_config.

Must be implemented by subclasses. Example implementation: Please see the make_dataloaders method in the ModelWrapper class.

Parameters:
run_configRunConfig

An instance of RunConfig containing configuration details.

Raises:
NotImplementedError

If this method is not implemented by the subclasses.

abstract save_dict() dict[str, Any] | None[source]#

Abstract method to create and return a save dictionary containing the model’s state and other necessary information.

Must be implemented by subclasses. Example implementation: Please see the save_dict method in the ModelWrapper class.

Returns:
dict[str, ty.Any] | None

A dictionary containing the saved model state and other necessary information.

Raises:
NotImplementedError

If this method is not implemented by the subclasses.

abstract train(run_config: RunConfig, smoke_test: bool = False)[source]#

Abstract method to train the model. Must be implemented by subclasses. Example implementation: Please see the train method in the ModelWrapper class.

Parameters:
run_configRunConfig

An instance of RunConfig containing configuration details.

smoke_testbool

Whether to run as a smoke test, by default False.

Raises:
NotImplementedError

If this method is not implemented by the subclasses.

property train_stats: OrderedDict#

Returns an ordered dictionary containing the current training statistics.

Returns:
OrderedDict

An ordered dictionary with the following keys and values:

  • learning_rate: The current learning rate.

  • total_steps: The total steps for the training process.

  • epochs: The number of epochs for training.

  • current_epoch: The current epoch during training.

  • current_iteration: The current iteration during training.

  • best_iteration: The iteration with the best loss value so far.

  • best_*: The best (lowest) optim metric value achieved during training, for example best_val_loss

property uid: str#

Returns a unique identifier (UID) for the current run configuration.

Returns:
str

A string representing the unique identifier of the current run configuration.

exception ablator.main.model.main.TrainPlateauError[source]#

Bases: Exception

ablator.main.model.wrapper module#

class ablator.main.model.wrapper.ModelWrapper(model_class: type[torch.nn.modules.module.Module])[source]#

Bases: ModelBase

A wrapper around model_class that removes training boilerplate code. Once make_dataloader_train is overridden to provide a training dataset, you can pass the ModelWrapper object to a trainer (ProtoTrainer or ParallelTrainer) along with a running configuration (RunConfig or ParallelConfig) to launch the experiment.

The wrapper lets you customize the training process by overriding various functions, which makes it adaptable to different training paradigms. Several customization use cases are shown below.

Parameters:
model_classtype[nn.Module]

The model class to wrap.

Attributes:
model_classtorch.nn.Module

The model class to wrap.

modeltorch.nn.Module

The model created from the model class or checkpoint

optimizerOptimizer

The optimizer created from the optimizer config or checkpoint

scalerGradScaler

The scaler created from the scaler config or checkpoint

schedulerScheduler

The scheduler created from the scheduler config or checkpoint

apply_loss(model: Module, loss: Tensor | None, optimizer: Optimizer, scaler: GradScaler) float | None[source]#

Calculate the loss and apply the gradients, call optimizer.step() and scheduler.step().

Parameters:
modelnn.Module

The model to apply the loss to.

losstorch.Tensor | None

The loss to apply.

optimizerOptimizer

The optimizer to step.

scaler: torch.cuda.amp.GradScaler

The scaler to use during mixed precision training.

Returns:
float | None

The loss value.

final backward(loss: Tensor, scaler: GradScaler | None = None)[source]#

checkpoint(is_best: bool = False)[source]#

Save a checkpoint of the model. It will use the class name of the model as the filename.

Parameters:
is_bestbool

Whether this is the best model so far, by default False.

config_parser(run_config: RunConfig) RunConfig[source]#

You can override this function to initialize derived properties that are not decided until the experiment is launched.

Parameters:
run_configRunConfig

The run config for the experiment.

Returns:
RunConfig

Examples

For example, in GPT2 models, sometimes we need to resize its vocabulary size to match the tokenizer’s vocabulary size depending on the tokenizer used. This is decided only after the experiment has been launched. The reason for this is that you might want to run an ablation study on different tokenizers:

>>> class MyLMWrapper(ModelWrapper):
...    def config_parser(self, run_config: RunConfig):
...        run_config.model_config.resize_token_embedding = len(self.train_dataloader.dataset.tokenizer)
...        return run_config
create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True)[source]#

Creates the model, optimizer, scheduler, and scaler from a saved checkpoint dictionary or from configuration objects. You can overwrite this function and save_dict() function to customize the saving and loading of the model, optimizer, and scheduler to your needs. An example for this is shown in Saving and loading multi-module models tutorial.

Parameters:
save_dictdict[str, ty.Any] | None

The saved checkpoint dictionary to load from, by default None.

strict_loadbool

Whether to throw an error for mismatched keys, by default True.
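The effect of strict_load can be sketched with plain dictionaries. This is a hypothetical illustration in the spirit of the strict argument of torch.nn.Module.load_state_dict; whether ablator forwards strict_load to that call is an assumption, and load_state is not an ablator function.

```python
def load_state(current, saved, strict=True):
    """Copy matching entries of `saved` into a copy of `current`.

    With strict=True any key mismatch raises; otherwise mismatches are
    reported and the matching subset is loaded.
    """
    missing = sorted(set(current) - set(saved))
    unexpected = sorted(set(saved) - set(current))
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    updated = dict(current)
    updated.update({k: v for k, v in saved.items() if k in current})
    return updated, missing, unexpected


# Floats stand in for parameter tensors.
model_state = {"encoder.weight": 0.0, "head.weight": 0.0}
checkpoint = {"encoder.weight": 1.5, "old_head.weight": 2.0}

loaded, missing, unexpected = load_state(model_state, checkpoint, strict=False)

try:
    load_state(model_state, checkpoint, strict=True)
    strict_failed = False
except KeyError:
    strict_failed = True
```

Non-strict loading keeps the freshly initialized head.weight and reports old_head.weight as unexpected, while strict loading rejects the same checkpoint.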

create_optimizer(model: Module, optimizer_config: OptimizerConfig | None = None, optimizer_state: dict[str, Any] | None = None) Optimizer[source]#

Creates the optimizer from the saved state or from config.

Parameters:
modelnn.Module

The model to create the optimizer for.

optimizer_configOptimizerConfig | None

The optimizer config to create the optimizer from, by default None.

optimizer_statedict[str, ty.Any] | None

The optimizer state to load the optimizer from, by default None.

Returns:
Optimizer

The optimizer.

create_scaler(scaler_state: dict[str, Any] | None = None) GradScaler[source]#

Creates the scaler from the saved state or from config.

Parameters:
scaler_statety.Optional[dict[str, ty.Any]]

The scaler state to load the scaler from, optional, by default None.

Returns:
GradScaler

The scaler.

create_scheduler(model: Module, optimizer: Optimizer, scheduler_config: SchedulerConfig | None = None, scheduler_state: dict[str, Any] | None = None) _LRScheduler | ReduceLROnPlateau | Any | None[source]#

Creates the scheduler from the saved state or from config.

Parameters:
modelnn.Module

The model to create the scheduler for.

optimizerOptimizer

The optimizer to create the scheduler for.

scheduler_configSchedulerConfig | None

The scheduler config to create the scheduler from, by default None.

scheduler_statedict[str, ty.Any] | None

The scheduler state to load the scheduler from, by default None.

Returns:
ty.Optional[Scheduler]

The scheduler.

final eval(smoke_test: bool = False)[source]#

Evaluates the model, then updates the scheduler and saves a checkpoint if the current iteration is an evaluation step. It also checks whether early stopping applies (see the Model Configuration module for more details).

Parameters:
smoke_testbool

If True, for a smoke test, by default False.

Raises:
LossDivergedError

If the loss diverges during evaluation.

TrainPlateauError

If the loss plateaus, i.e. its improvement slows down dramatically.

final evaluate(run_config: RunConfig, chkpt: str | Path | None = None) dict[str, dict[str, Any]][source]#

Evaluate the model after training on the test and validation sets.

Parameters:
run_configRunConfig

The run config to use for evaluation.

chkpt: str | Path | None

Path to the checkpoint to evaluate. If None, the latest checkpoint is evaluated, by default None.

Returns:
dict[str, dict[str, ty.Any]]

Metrics

evaluation_functions() dict[str, collections.abc.Callable] | None[source]#

You can overwrite this function and return the evaluation functions’ callables that will be used to evaluate experiment metrics.

Returns:
dict[str, Callable] | None

The evaluation functions to use.

Examples

  • Define the callables:

>>> from sklearn.metrics import accuracy_score, f1_score
>>> def my_accuracy(y_true, y_pred):
...     return accuracy_score(y_true.flatten(), y_pred.flatten())
>>> def my_f1_score(y_true, y_pred):
...     return f1_score(y_true.flatten(), y_pred.flatten(), average='weighted')

  • Return the callables from evaluation_functions:

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
...     def evaluation_functions(self):
...         return {
...             "accuracy": my_accuracy,
...             "f1": my_f1_score
...         }

  • Note that the callable’s parameter names must match the keys of the model’s forward output. In our example, y_true and y_pred must be returned by the model’s forward method to match the y_true and y_pred parameters of the my_accuracy and my_f1_score functions:

>>> from torch import nn
>>> class MyModel(nn.Module):
...     def __init__(self, config: CustomModelConfig) -> None:
...         super().__init__()
...         self.model = FashionMNISTModel(config)
...         self.loss = nn.CrossEntropyLoss()
...     def forward(self, x, labels=None):
...         out = self.model(x)
...         loss = None
...         if labels is not None:
...             loss = self.loss(out, labels)
...         out = out.argmax(dim=-1)
...         return {"y_pred": out, "y_true": labels}, loss

load_checkpoint(save_dict: dict[str, Any], model_only: bool = False)[source]#

Loads the checkpoint from the save dict.

Parameters:
save_dictdict[str, ty.Any]

The save dict to load the checkpoint from.

model_onlybool

Whether to load only the model or include scheduler, optimizer and scaler, by default False.

Notes

This method is the implementation of the abstract method in the base class.

log()[source]#

Logs if the current iteration is a logging step. It also evaluates training metrics for logging.

log_step()[source]#

A single step for logging.

Notes

This method updates the logger with the current metrics and logs a status message.

make_dataloader_test(run_config: RunConfig) DataLoader | None[source]#

Function to make the test dataloader. You can overwrite this function and return the test dataloader.

Parameters:
run_config: RunConfig

The run configuration.

Returns:
DataLoader | None

The test dataloader.

Examples

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_test(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             test_dataset,
...             batch_size=32,
...             shuffle=False
...         )

abstract make_dataloader_train(run_config: RunConfig) DataLoader[source]#

Function to make the training dataloader. You must overwrite this function and return the training dataloader.

Parameters:
run_config: RunConfig

The run configuration.

Returns:
DataLoader

The training dataloader.

Examples

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )

make_dataloader_val(run_config: RunConfig) DataLoader | None[source]#

Function to make the validation dataloader. You can overwrite this function and return the validation dataloader.

Parameters:
run_configRunConfig

The run configuration.

Returns:
DataLoader | None

The validation dataloader.

Examples

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )

make_dataloaders(run_config: RunConfig) None[source]#

This function is called after initialization because otherwise the dataloaders would be pickled with the object when running distributed.

Parameters:
run_configRunConfig

The run config for the experiment.

property metrics: dict[str, float]#

The metrics of the current training state. If eval_metrics are defined (e.g. a validation dataloader was provided), the training and evaluation metrics are combined, with a train_ and val_ label prepended respectively. The current training statistics are also included.

Returns:
dict[str, float]

A dictionary with keys the metric name and the value corresponding to the metric.
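The combination described above can be sketched as a small pure-Python function. This is an illustration, not ablator's implementation: combine_metrics is a hypothetical helper, and leaving the training metrics unprefixed when no evaluation metrics exist is an assumption that the documentation does not confirm.

```python
def combine_metrics(train_metrics, eval_metrics, train_stats):
    """Merge metrics into one flat dictionary.

    Assumption: prefixes are applied only when eval metrics are present.
    """
    combined = {}
    if eval_metrics is not None:
        combined.update({f"train_{k}": v for k, v in train_metrics.items()})
        combined.update({f"val_{k}": v for k, v in eval_metrics.items()})
    else:
        combined.update(train_metrics)
    # Training statistics (learning rate, iteration counters, ...) come last.
    combined.update(train_stats)
    return combined


with_val = combine_metrics(
    {"loss": 0.42},
    {"loss": 0.51},
    {"learning_rate": 0.001, "current_iteration": 100},
)
without_val = combine_metrics({"loss": 0.42}, None, {"learning_rate": 0.001})
```

In this sketch with_val contains the keys train_loss, val_loss, learning_rate and current_iteration.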

property model_config: ModelConfig#

model_step(model: Module, batch: Iterable) tuple[dict[str, torch.Tensor] | None, torch.Tensor | None][source]#

A single inference step for the model.

Parameters:
modelnn.Module

The model to train.

batchIterable

The batch of input data to pass through the model. It can be a list, a dict or a single tensor.

Returns:
tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]

The output of the model, which contains the current predictions and the loss of the model.

final optim_step(optimizer: Optimizer, scaler: GradScaler, model: Module, loss: Tensor | None)[source]#

save_dict() dict[str, Any][source]#

Save the current state of the trainer, including model parameters, and the current states of the optimizer, scaler, and scheduler. You can overwrite this function and create_model() to customize the saving and loading of the model, optimizer, and scheduler to your needs. An example of this is shown in Saving and loading multi-module models tutorial.

Returns:
dict[str, ty.Any]

The current state of the trainer.

  • model: the model’s state

  • optimizer: the state of optimizer

  • scheduler: the scheduler’s state

  • scaler: the state of the GradScaler.

final scheduler_step(*args, is_val_step=False)[source]#

A single scheduler step. This function accounts for when the scheduler is supposed to take a step. It reads step_when as a property from the scheduler or derives it from the train configuration. step_when is expected to be in {“train”, “epoch”, “val”}. When the scheduler is a validation scheduler, it expects some metrics to be passed as arguments to the scheduler step function, e.g. wrapper.scheduler_step(0.001).

When no args are provided and it is a validation scheduler, this function is a no-op.
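The dispatch on step_when can be sketched as follows. This is a hypothetical illustration, not ablator's implementation: RecordingScheduler and this free-standing scheduler_step are invented for the example, and the exact conditions for "train" (every training iteration) and "epoch" (every epoch_len iterations) are assumptions. Only the "val" behaviour, including the no-op without arguments, is taken from the description above.

```python
class RecordingScheduler:
    """Hypothetical scheduler that records the arguments of each step."""

    def __init__(self, step_when):
        self.step_when = step_when
        self.calls = []

    def step(self, *args):
        self.calls.append(args)


def scheduler_step(scheduler, iteration, epoch_len, *args, is_val_step=False):
    """Step the scheduler only when its step_when policy says so."""
    step_when = scheduler.step_when
    if step_when not in {"train", "epoch", "val"}:
        raise ValueError(f"Unexpected step_when: {step_when!r}")
    if step_when == "val":
        # Validation schedulers need a metric; without one this is a no-op.
        if is_val_step and args:
            scheduler.step(*args)
        return
    if is_val_step:
        return
    if step_when == "train" or iteration % epoch_len == 0:
        scheduler.step()


train_sched = RecordingScheduler("train")
epoch_sched = RecordingScheduler("epoch")
val_sched = RecordingScheduler("val")
epoch_len = 3

for iteration in range(1, 7):
    for sched in (train_sched, epoch_sched, val_sched):
        scheduler_step(sched, iteration, epoch_len)

scheduler_step(val_sched, 6, epoch_len, 0.001, is_val_step=True)
scheduler_step(val_sched, 6, epoch_len, is_val_step=True)  # no-op: no metric
```

After six iterations with an epoch length of three, the "train" scheduler has stepped six times, the "epoch" scheduler twice, and the "val" scheduler once with the metric 0.001.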

status_message() str[source]#

Return a string generated from dictionary of current metrics, including all the static metrics and moving average metrics.

Returns:
str

The status message.

to_device(data: Iterable, device: str | None = None) Iterable[source]#

Moves the data to the specified device.

Parameters:
dataIterable

The data to move to the device.

devicety.Optional[str]

The device to move the data to. If None, the device specified in the config is used.

Returns:
Iterable

The data on the device.
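Moving nested batch data can be sketched as a recursive traversal. This is an illustration under stated assumptions, not ablator's implementation: apply_nested is a hypothetical helper, and tagging each leaf with a device name stands in for a call such as tensor.to(device).

```python
def apply_nested(data, fn):
    """Apply `fn` to every leaf of a nested dict / list / tuple structure."""
    if isinstance(data, dict):
        return {key: apply_nested(value, fn) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_nested(value, fn) for value in data)
    return fn(data)


def fake_to_device(leaf, device="cuda:0"):
    # Stand-in for `leaf.to(device)`; returns the leaf tagged with the device.
    return (device, leaf)


batch = {"x": [1, 2], "labels": (3,)}
moved = apply_nested(batch, fake_to_device)
```

The container types are preserved, so moved is {"x": [("cuda:0", 1), ("cuda:0", 2)], "labels": (("cuda:0", 3),)}.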

property total_steps: int#

The total number of steps for training.

Returns:
int

total steps = epoch’s length * number of epochs.
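The formula above can be worked through with concrete numbers. The helper names are hypothetical, the dataset size and batch size are made-up example values, and deriving the current epoch by integer division is an assumption rather than a documented rule.

```python
import math


def epoch_length(n_samples, batch_size):
    """Number of batches per epoch when the last partial batch is kept."""
    length = math.ceil(n_samples / batch_size)
    if length == 0:
        raise RuntimeError("The train dataloader is empty.")
    return length


def total_steps(epoch_len, epochs):
    """total steps = epoch's length * number of epochs."""
    return epoch_len * epochs


def current_epoch(current_iteration, epoch_len):
    """Assumed derivation: completed epochs at a given iteration."""
    return current_iteration // epoch_len


epoch_len = epoch_length(n_samples=60_000, batch_size=32)  # 1875 batches
steps = total_steps(epoch_len, epochs=10)                   # 18750 steps
epoch_now = current_epoch(4_000, epoch_len)                 # 2
```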

final train(run_config: RunConfig | None = None, smoke_test: bool = False, debug: bool = False, resume: bool = False) dict[str, float][source]#

Initializes the state and trains the model. On a keyboard interrupt, it saves a checkpoint.

Parameters:
run_configRunConfig | None

The run config to use for training, by default None.

smoke_testbool

Whether to run a smoke test, by default False.

debugbool

Whether to run in debug mode, by default False.

resumebool

Whether to resume training the model from existing checkpoints and existing experiment state, by default False.

Returns:
dict[str, float]

The metrics from the training.

Raises:
ValueError

If the state is not initialized and no run_config is provided, or if a run_config is provided but the state is already initialized.

property train_config: TrainConfig#

train_loop(smoke_test: bool = False) dict[str, float][source]#

Trains the model over many steps, evaluating the model and logging the metrics for each iteration. The metrics include static metrics such as the learning rate, along with validation and training metrics such as the loss and mean.

Parameters:
smoke_testbool

Whether to run a smoke test.

Returns:
dict[str, float]

metrics after completion of training.

Raises:
ValueError

If model outputs are not stored in correct format.

LossDivergedError

If the loss diverged.

final train_step(batch: Iterable) tuple[dict[str, torch.Tensor] | None, dict[str, Any]][source]#

A single step for training. It also updates learning rate with scheduler.

Parameters:
batchIterable

The batch of input data to pass through the model. It can be a list, a dict or a single tensor.

Returns:
tuple[dict[str, torch.Tensor] | None, dict[str, ty.Any]]
  • outputsdict[str, torch.Tensor] | None

    The output of the model.

  • train_metrics: dict[str, ty.Any]

    The training metrics.

update_status() None[source]#

Update the metrics with current training stats, and then all metrics (static and moving average) will be set as description for the tqdm progress.

validation_loop(model: Module, dataloader: DataLoader, metrics: Metrics, subsample: float = 1.0, smoke_test: bool = False) dict[str, float][source]#

Validates the model on the data from dataloader and stores the results in metrics.

Parameters:
modelnn.Module

The model to validate.

dataloaderDataLoader

The dataloader to use for validation.

metricsMetrics

The metrics to use for validation.

subsamplefloat

The fraction of the dataloader to use for validation, by default 1.0.

smoke_testbool

Whether to execute this function as a smoke test. If True, only one iteration will be performed, which is useful for quickly checking if the code runs without errors, by default False.

Returns:
dict[str, float]

The metrics from the validation.
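How subsample and smoke_test limit the amount of validation work can be sketched as follows. The helper batches_to_evaluate is hypothetical, and rounding the subsampled batch count up is an assumption; the documentation only states that subsample is the fraction of the dataloader used and that a smoke test performs a single iteration.

```python
import math


def batches_to_evaluate(n_batches, subsample=1.0, smoke_test=False):
    """Return how many batches a validation pass would visit."""
    if not 0 < subsample <= 1:
        raise ValueError("subsample must be in the interval (0, 1].")
    if smoke_test:
        # A smoke test only checks that one iteration runs without errors.
        return min(1, n_batches)
    return math.ceil(n_batches * subsample)


full_pass = batches_to_evaluate(313)                      # 313
quarter_pass = batches_to_evaluate(313, subsample=0.25)   # 79
smoke_pass = batches_to_evaluate(313, smoke_test=True)    # 1
```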

Module contents#