Model Training Interface

class ablator.main.model.wrapper.ModelWrapper(model_class: type[torch.nn.modules.module.Module])

Bases: ModelBase

A wrapper around model_class that removes training boilerplate code. Its methods can be overridden to support custom use cases. Once make_dataloader_train is overridden to provide a training dataset, you can pass the ModelWrapper object to a trainer (ProtoTrainer or ParallelTrainer) together with a run configuration (RunConfig or ParallelConfig) to launch the experiment.

Attributes:
model_class: type[torch.nn.Module]

The model class to wrap.

model: torch.nn.Module

The model created from the model class or checkpoint.

optimizer: Optimizer

The optimizer created from the optimizer config or checkpoint.

scaler: GradScaler

The scaler created from the scaler config or checkpoint.

scheduler: Scheduler

The scheduler created from the scheduler config or checkpoint.

config_parser(run_config: RunConfig)

Override this function to initialize derived properties whose values are not known until the experiment is launched.

Examples

For example, in a GPT-2 model, the token embedding must be resized to match the vocabulary size of the tokenizer in use. That size is known only after the experiment has been launched, because you might want to run an ablation study over different tokenizers:

>>> class MyLMWrapper(ModelWrapper):
...    def config_parser(self, run_config: RunConfig):
...        run_config.model_config.resize_token_embedding = len(self.train_dataloader.dataset.tokenizer)
...        return run_config
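The derived-property pattern can be sketched without ablator itself. The sketch below uses hypothetical stand-ins (ModelConfig, RunConfig, CharTokenizer) for the real config classes and tokenizer. It only shows how a field left unset in the config is filled in from an object that exists only at launch time.

```python
from dataclasses import dataclass, field
from typing import Optional


# Hypothetical stand-ins for the real ablator config classes.
@dataclass
class ModelConfig:
    # Derived property: unknown until the tokenizer is chosen at launch time.
    resize_token_embedding: Optional[int] = None


@dataclass
class RunConfig:
    model_config: ModelConfig = field(default_factory=ModelConfig)


class CharTokenizer:
    """Hypothetical tokenizer whose vocabulary is built from a corpus."""

    def __init__(self, corpus: str) -> None:
        self.vocab = sorted(set(corpus))

    def __len__(self) -> int:
        return len(self.vocab)


def config_parser(run_config: RunConfig, tokenizer: CharTokenizer) -> RunConfig:
    # Fill in the derived property from the tokenizer actually in use.
    run_config.model_config.resize_token_embedding = len(tokenizer)
    return run_config


run_config = config_parser(RunConfig(), CharTokenizer("hello world"))
print(run_config.model_config.resize_token_embedding)  # → 8
```

Swapping in a different tokenizer changes the derived value without touching the config file, which is what makes tokenizer ablations possible.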
create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) → None

Creates the model, optimizer, scheduler, and scaler, either from a saved checkpoint dictionary or from the config. You can override this function together with save_dict() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.

Parameters:
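save_dict: dict[str, Any] | None

The saved checkpoint dictionary to load from. If None, the model is created from the config.

strict_load: bool

Whether to load the model state strictly, that is, whether the keys in the checkpoint must exactly match the model's keys.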
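Assuming strict_load is forwarded to torch.nn.Module.load_state_dict(strict=...) (an assumption; this page does not say so), strict loading fails on any missing or unexpected key, while non-strict loading only reports them. The pure-Python sketch below mimics that key check; the helper check_state_dict_keys is hypothetical.

```python
from collections.abc import Iterable


def check_state_dict_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str], strict: bool = True
) -> tuple[list[str], list[str]]:
    """Return (missing, unexpected) keys; raise if strict and they differ."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)  # in the model, absent from the checkpoint
    unexpected = sorted(ckpt_set - model_set)  # in the checkpoint, unknown to the model
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    return missing, unexpected


model_keys = ["encoder.weight", "encoder.bias", "head.weight"]
# e.g. a checkpoint saved before a new head was added to the model
checkpoint_keys = ["encoder.weight", "encoder.bias"]

# Non-strict loading tolerates the mismatch and reports it.
print(check_state_dict_keys(model_keys, checkpoint_keys, strict=False))
# → (['head.weight'], [])
```

Non-strict loading is useful when fine-tuning from a checkpoint of a slightly different architecture; strict loading catches accidental mismatches.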

evaluation_functions() → dict[str, collections.abc.Callable] | None

Override this function to return the evaluation callables used to compute experiment metrics.

Returns:
dict[str, Callable]

The evaluation functions to use. Also see Metrics for details.

Examples

  • Define the callables:

>>> from sklearn.metrics import accuracy_score, f1_score
>>> def my_accuracy(y_true, y_pred):
...     return accuracy_score(y_true.flatten(), y_pred.flatten())
>>> def my_f1_score(y_true, y_pred):
...     return f1_score(y_true.flatten(), y_pred.flatten(), average='weighted')
  • Return the callables from evaluation_functions:

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
...     def evaluation_functions(self):
...         return {
...             "accuracy": my_accuracy,
...             "f1": my_f1_score
...         }
  • The callables' parameter names must match the keys of the dictionary returned by the model's forward method. In this example, forward must return y_true and y_pred so that they can be passed to the y_true and y_pred parameters of my_accuracy and my_f1_score:

>>> from torch import nn
>>> class MyModel(nn.Module):
...     def __init__(self, config: CustomModelConfig) -> None:
...         super().__init__()
...         self.model = FashionMNISTModel(config)
...         self.loss = nn.CrossEntropyLoss()
...     def forward(self, x, labels=None):
...         out = self.model(x)
...         loss = None
...         if labels is not None:
...             loss = self.loss(out, labels)
...         out = out.argmax(dim=-1)
...         return {"y_pred": out, "y_true": labels}, loss
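The name-matching rule can be illustrated with a small dispatcher that inspects each callable's signature and passes it the forward outputs of the same name. This is a sketch of the idea, not ablator's implementation: compute_metrics is a hypothetical helper, and the metrics use plain lists instead of tensors so the example runs without PyTorch or scikit-learn.

```python
import inspect
from collections.abc import Callable


def my_accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_count(y_true, y_pred):
    return sum(t != p for t, p in zip(y_true, y_pred))


def compute_metrics(
    eval_fns: dict[str, Callable], forward_output: dict[str, list]
) -> dict[str, float]:
    """Call each metric with the forward outputs named by its parameters."""
    results = {}
    for name, fn in eval_fns.items():
        params = inspect.signature(fn).parameters
        missing = [p for p in params if p not in forward_output]
        if missing:
            raise KeyError(f"{name!r} expects {missing}, not returned by forward()")
        results[name] = fn(**{p: forward_output[p] for p in params})
    return results


outputs = {"y_true": [0, 1, 2, 1], "y_pred": [0, 1, 1, 1]}
print(compute_metrics({"accuracy": my_accuracy, "errors": error_count}, outputs))
# → {'accuracy': 0.75, 'errors': 1}
```

Renaming a key in forward's output dictionary without renaming the metric parameters would make the lookup fail, which is why the names must agree.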
make_dataloader_test(run_config: RunConfig) → DataLoader | None

Creates the test dataloader. Override this function to return the test dataloader.

Parameters:
run_config: RunConfig

The run configuration.

Returns:
DataLoader | None

The test dataloader.

Examples

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_test(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             test_dataset,
...             batch_size=32,
...             shuffle=False
...         )
abstract make_dataloader_train(run_config: RunConfig) → DataLoader

Creates the training dataloader. You must override this function and return the training dataloader.

Parameters:
run_config: RunConfig

The run configuration.

Returns:
DataLoader

The training dataloader.

Examples

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
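Because make_dataloader_train is abstract, a subclass that does not override it cannot be instantiated, while the validation and test loaders are optional. The sketch below reproduces that contract with Python's abc module. WrapperSketch is a hypothetical stand-in for ModelWrapper, plain lists stand in for DataLoader objects, and the None defaults of the optional hooks are an assumption based on their DataLoader | None return types.

```python
from abc import ABC, abstractmethod


class WrapperSketch(ABC):
    """Hypothetical stand-in mirroring the dataloader hooks of ModelWrapper."""

    @abstractmethod
    def make_dataloader_train(self, run_config):
        """Required: subclasses must return the training loader."""

    def make_dataloader_val(self, run_config):
        return None  # optional hook; None default is an assumption

    def make_dataloader_test(self, run_config):
        return None  # optional hook; None default is an assumption


class IncompleteWrapper(WrapperSketch):
    pass  # forgets to override make_dataloader_train


class MyWrapper(WrapperSketch):
    def make_dataloader_train(self, run_config):
        return [[1, 2], [3, 4]]  # stand-in for a DataLoader yielding batches


try:
    IncompleteWrapper()
except TypeError as err:
    print("cannot instantiate:", err)

wrapper = MyWrapper()
print(wrapper.make_dataloader_train(None), wrapper.make_dataloader_val(None))
```

The error surfaces at construction time rather than midway through training, so a missing training loader is caught before any work starts.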
make_dataloader_val(run_config: RunConfig) → DataLoader | None

Creates the validation dataloader. Override this function to return the validation dataloader.

Parameters:
run_config: RunConfig

The run configuration.

Returns:
DataLoader | None

The validation dataloader.

Examples

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
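The dataloader examples assume that train_dataset and val_dataset already exist. One way to derive both from a single dataset is a seeded random split, so that every trial of an ablation study sees the same partition. In PyTorch this is typically done with torch.utils.data.random_split and a seeded torch.Generator; the stdlib-only sketch below shows the same idea with a hypothetical split_indices helper.

```python
import random


def split_indices(n: int, val_fraction: float, seed: int = 0) -> tuple[list[int], list[int]]:
    """Shuffle range(n) with a fixed seed and cut it into train/val index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG: does not touch global state
    n_val = int(n * val_fraction)
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(10, val_fraction=0.2, seed=42)
print(len(train_idx), len(val_idx))  # → 8 2
```

Fixing the seed keeps the validation set identical across trials, so metric differences reflect the ablated component rather than a different data split.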
save_dict() → dict[str, Any]

Returns the current state of the trainer, including the model parameters and the states of the optimizer, scaler, and scheduler. You can override this function together with create_model() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.

Returns:
dict[str, Any]

The current state of the trainer.
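Overriding save_dict and create_model together amounts to agreeing on a dictionary layout that one method writes and the other reads. The sketch below shows that symmetry for a model with two sub-modules. It is a hypothetical, framework-free stand-in: MultiModuleSketch and its dict-valued state replace real modules, whose state would normally come from state_dict() and be restored with load_state_dict().

```python
import copy
from typing import Any, Optional


class MultiModuleSketch:
    """Hypothetical wrapper with an encoder and a decoder holding plain-dict state."""

    def __init__(self) -> None:
        self.encoder: dict[str, float] = {}
        self.decoder: dict[str, float] = {}
        self.step = 0

    def create_model(self, save_dict: Optional[dict[str, Any]] = None) -> None:
        if save_dict is None:
            # Fresh initialization (stands in for building from the config).
            self.encoder = {"w": 0.0}
            self.decoder = {"w": 0.0}
            self.step = 0
        else:
            # Restore each sub-module from its own key in the checkpoint.
            self.encoder = copy.deepcopy(save_dict["encoder"])
            self.decoder = copy.deepcopy(save_dict["decoder"])
            self.step = save_dict["step"]

    def save_dict(self) -> dict[str, Any]:
        # Keys written here must be exactly the keys create_model reads.
        return {
            "encoder": copy.deepcopy(self.encoder),
            "decoder": copy.deepcopy(self.decoder),
            "step": self.step,
        }


original = MultiModuleSketch()
original.create_model()
original.encoder["w"], original.step = 1.5, 10  # pretend some training happened

restored = MultiModuleSketch()
restored.create_model(original.save_dict())
print(restored.encoder, restored.step)  # → {'w': 1.5} 10
```

Deep-copying on both sides keeps the checkpoint independent of the live objects, so later training steps cannot silently mutate a saved state.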