Model Training Interface

class ablator.main.model.wrapper.ModelWrapper(model_class: type[torch.nn.modules.module.Module])

Bases: ModelBase

A wrapper around model_class that removes training boilerplate code. Once make_dataloader_train is overridden to provide a training dataset, you can pass a ModelWrapper object to a trainer (ProtoTrainer or ParallelTrainer) along with a run configuration (RunConfig or ParallelConfig) to launch the experiment.

The wrapper lets you customize the training process by overriding various methods, which makes it adaptable to different training paradigms. Several customization use cases are shown below.

Parameters:
model_class : type[nn.Module]

The model class to wrap.

Attributes:
model_class : type[torch.nn.Module]

The model class to wrap.

model : torch.nn.Module

The model created from the model class or a checkpoint.

optimizer : Optimizer

The optimizer created from the optimizer config or a checkpoint.

scaler : GradScaler

The scaler created from the scaler config or a checkpoint.

scheduler : Scheduler

The scheduler created from the scheduler config or a checkpoint.

config_parser(run_config: RunConfig) -> RunConfig

You can override this function to initialize derived properties that are not determined until the experiment is launched.

Parameters:
run_config : RunConfig

The run config for the experiment.

Returns:
RunConfig

Examples

For example, a GPT-2 model’s vocabulary size sometimes needs to be resized to match the tokenizer’s vocabulary size. That size depends on the tokenizer used, so it is known only after the experiment has been launched. This matters when, for instance, you run an ablation study over different tokenizers:

>>> class MyLMWrapper(ModelWrapper):
...    def config_parser(self, run_config: RunConfig):
...        run_config.model_config.resize_token_embedding = len(self.train_dataloader.dataset.tokenizer)
...        return run_config
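
The same idea can be sketched without ablator or a real tokenizer. Everything below (ToyModelConfig, ToyRunConfig, TOY_VOCABS) is a hypothetical stand-in, not part of the library; the point is only that the derived value is filled in from data that becomes available at launch time:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyModelConfig:
    # Derived property: unknown until a tokenizer has been chosen.
    vocab_size: Optional[int] = None


@dataclass
class ToyRunConfig:
    tokenizer_name: str = "small"
    model_config: ToyModelConfig = field(default_factory=ToyModelConfig)


# Hypothetical vocabularies standing in for real tokenizers.
TOY_VOCABS = {
    "small": ["<pad>", "a", "b"],
    "large": ["<pad>", "a", "b", "c", "d"],
}


def config_parser(run_config: ToyRunConfig) -> ToyRunConfig:
    """Fill in the derived property once the tokenizer is known."""
    vocab = TOY_VOCABS[run_config.tokenizer_name]
    run_config.model_config.vocab_size = len(vocab)
    return run_config


# One resolved config per tokenizer, as in an ablation over tokenizers.
resolved = [config_parser(ToyRunConfig(tokenizer_name=name)) for name in TOY_VOCABS]
vocab_sizes = [cfg.model_config.vocab_size for cfg in resolved]
```

Each trial ends up with a vocabulary size that matches its own tokenizer, even though the base configuration never specified one.
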

create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True)

Creates the model, optimizer, scheduler, and scaler from a saved checkpoint dictionary or from configuration objects. You can override this function together with save_dict() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.

Parameters:
save_dict : dict[str, Any] | None

The saved checkpoint dictionary to load from, by default None.

strict_load : bool

Whether to raise an error for mismatched keys, by default True.
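
To illustrate what a strict versus non-strict load means, here is a minimal sketch using plain dictionaries in place of real state dicts. It assumes semantics similar to the strict argument of PyTorch's load_state_dict (an error when keys are missing or unexpected); load_state is a hypothetical helper, and ablator's actual handling of strict_load may differ in its details:

```python
def load_state(current: dict, saved: dict, strict: bool = True) -> dict:
    """Return a copy of `current` updated with the matching entries of `saved`."""
    missing = sorted(current.keys() - saved.keys())
    unexpected = sorted(saved.keys() - current.keys())
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"missing keys: {missing}, unexpected keys: {unexpected}"
        )
    merged = dict(current)
    # In non-strict mode, only keys present on both sides are loaded.
    merged.update({k: v for k, v in saved.items() if k in current})
    return merged


current = {"encoder.weight": 0.0, "head.weight": 0.0}
saved = {"encoder.weight": 1.5, "old_head.weight": 2.5}

# Non-strict: the shared key is loaded, the mismatched ones are ignored.
lenient = load_state(current, saved, strict=False)

# Strict: the same checkpoint is rejected.
try:
    load_state(current, saved, strict=True)
    strict_failed = False
except RuntimeError:
    strict_failed = True
```

A non-strict load is mainly useful when the architecture changed between runs, for example when reusing an encoder with a new head.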

evaluation_functions() -> dict[str, collections.abc.Callable] | None

You can override this function to return the callables that will be used to evaluate experiment metrics.

Returns:
dict[str, Callable] | None

The evaluation functions to use.

Examples

  • Define the callables:

>>> from sklearn.metrics import accuracy_score, f1_score
>>> def my_accuracy(y_true, y_pred):
...     return accuracy_score(y_true.flatten(), y_pred.flatten())
>>> def my_f1_score(y_true, y_pred):
...     return f1_score(y_true.flatten(), y_pred.flatten(), average='weighted')
  • Return the callables from evaluation_functions:

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
...     def evaluation_functions(self):
...         # Metric name -> callable that computes it.
...         return {
...             "accuracy": my_accuracy,
...             "f1": my_f1_score
...         }
  • Note that the callables’ parameter names must match the keys of the model’s forward output. In this example, the model’s forward method must return y_true and y_pred to match the y_true and y_pred parameters of my_accuracy and my_f1_score:

>>> class MyModel(nn.Module):
...     def __init__(self, config: CustomModelConfig) -> None:
...         super().__init__()
...         self.model = FashionMNISTModel(config)
...         self.loss = nn.CrossEntropyLoss()
...     def forward(self, x, labels=None):
...         out = self.model(x)
...         loss = None
...         if labels is not None:
...             loss = self.loss(out, labels)
...         out = out.argmax(dim=-1)
...         # Keys match the parameter names of the evaluation callables.
...         return {"y_pred": out, "y_true": labels}, loss
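
The name-matching rule can be sketched in plain Python with inspect.signature. This is only an illustration of the idea, not ablator's implementation; evaluate and error_count are hypothetical names, and lists stand in for tensors:

```python
import inspect


def my_accuracy(y_true, y_pred):
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def error_count(y_true, y_pred):
    return sum(t != p for t, p in zip(y_true, y_pred))


def evaluate(functions: dict, outputs: dict) -> dict:
    """Call each function with the forward outputs named by its parameters."""
    results = {}
    for name, fn in functions.items():
        params = inspect.signature(fn).parameters
        missing = [p for p in params if p not in outputs]
        if missing:
            raise KeyError(f"{name} needs forward outputs {missing}")
        results[name] = fn(**{p: outputs[p] for p in params})
    return results


# The keys play the role of the dictionary returned by forward().
outputs = {"y_pred": [1, 0, 1, 1], "y_true": [1, 0, 0, 1]}
metrics = evaluate({"accuracy": my_accuracy, "errors": error_count}, outputs)
```

If forward returned a key such as preds instead of y_pred, the lookup in this sketch would fail, which is why the names have to agree.
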

make_dataloader_test(run_config: RunConfig) -> DataLoader | None

Function to make the test dataloader. You can override this function to return the test dataloader.

Parameters:
run_config : RunConfig

The run configuration.

Returns:
DataLoader | None

The test dataloader.

Examples

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_test(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             test_dataset,
...             batch_size=32,
...             shuffle=False
...         )

abstract make_dataloader_train(run_config: RunConfig) -> DataLoader

Function to make the training dataloader. You must override this function and return the training dataloader.

Parameters:
run_config : RunConfig

The run configuration.

Returns:
DataLoader

The training dataloader.

Examples

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
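
The split between one required dataloader and two optional ones follows the usual abstract-method pattern. The sketch below shows that pattern with Python's abc module; ToyWrapper and TrainOnlyWrapper are hypothetical stand-ins, plain lists replace DataLoader objects, and the None defaults are an assumption based on the DataLoader | None return types rather than a copy of the library's code:

```python
from abc import ABC, abstractmethod


class ToyWrapper(ABC):
    """Stand-in for a wrapper with one required and two optional loaders."""

    @abstractmethod
    def make_dataloader_train(self, run_config):
        """Required: every subclass must provide training data."""

    def make_dataloader_val(self, run_config):
        return None  # optional: no validation data by default

    def make_dataloader_test(self, run_config):
        return None  # optional: no test data by default


class TrainOnlyWrapper(ToyWrapper):
    def make_dataloader_train(self, run_config):
        # Toy (features, label) pairs instead of a real DataLoader.
        return [([0.0, 1.0], 1), ([1.0, 0.0], 0)]


# The base class cannot be instantiated while the abstract method is missing.
try:
    ToyWrapper()
    instantiated = True
except TypeError:
    instantiated = False

wrapper = TrainOnlyWrapper()
train_batches = wrapper.make_dataloader_train(run_config=None)
val_batches = wrapper.make_dataloader_val(run_config=None)
```
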

make_dataloader_val(run_config: RunConfig) -> DataLoader | None

Function to make the validation dataloader. You can override this function to return the validation dataloader.

Parameters:
run_config : RunConfig

The run configuration.

Returns:
DataLoader | None

The validation dataloader.

Examples

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )

save_dict() -> dict[str, Any]

Saves the current state of the trainer, including the model parameters and the current states of the optimizer, scaler, and scheduler. You can override this function together with create_model() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.

Returns:
dict[str, Any]

The current state of the trainer, with the following keys:

  • model: the model’s state

  • optimizer: the optimizer’s state

  • scheduler: the scheduler’s state

  • scaler: the GradScaler’s state
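
As a rough illustration of how save_dict() and create_model() mirror each other, the sketch below round-trips a checkpoint with the four keys listed above. ToyComponent and ToyTrainerState are hypothetical stand-ins for objects that expose state_dict() and load_state_dict(); the real wrapper stores PyTorch objects and may include more logic:

```python
import copy

KEYS = ("model", "optimizer", "scheduler", "scaler")


class ToyComponent:
    """Stand-in for anything exposing state_dict() / load_state_dict()."""

    def __init__(self, **state):
        self.state = dict(state)

    def state_dict(self) -> dict:
        return copy.deepcopy(self.state)

    def load_state_dict(self, state: dict) -> None:
        self.state = copy.deepcopy(state)


class ToyTrainerState:
    def __init__(self):
        self.model = ToyComponent(weight=0.0)
        self.optimizer = ToyComponent(lr=0.1, step=0)
        self.scheduler = ToyComponent(last_epoch=0)
        self.scaler = ToyComponent(scale=65536.0)

    def save_dict(self) -> dict:
        # One entry per component, keyed by the names listed above.
        return {name: getattr(self, name).state_dict() for name in KEYS}

    def create_model(self, save_dict=None) -> None:
        # Without a checkpoint, the freshly constructed components are kept.
        if save_dict is not None:
            for name in KEYS:
                getattr(self, name).load_state_dict(save_dict[name])


# Simulate some training progress, then checkpoint and resume.
trained = ToyTrainerState()
trained.model.state["weight"] = 0.42
trained.optimizer.state["step"] = 10
checkpoint = trained.save_dict()

resumed = ToyTrainerState()
resumed.create_model(checkpoint)
```

When overriding one of the two methods, for example to store several sub-modules separately, the other usually needs a matching change so that whatever is written can be read back.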