Model Training Interface
- class ablator.main.model.wrapper.ModelWrapper(model_class: type[torch.nn.modules.module.Module])
Bases: ModelBase
A wrapper around model_class that removes training boilerplate code. Its methods can be overridden to support custom use cases. Once make_dataloader_train is overridden to provide a training dataset, you can pass the ModelWrapper object to a trainer (ProtoTrainer or ParallelTrainer) along with a run configuration (RunConfig or ParallelConfig) to launch the experiment.
- Attributes:
- model_class: torch.nn.Module
The model class to wrap.
- model: torch.nn.Module
The model, created from model_class or restored from a checkpoint.
- optimizer: Optimizer
The optimizer, created from the optimizer config or restored from a checkpoint.
- scaler: GradScaler
The gradient scaler, created from the scaler config or restored from a checkpoint.
- scheduler: Scheduler
The scheduler, created from the scheduler config or restored from a checkpoint.
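The override contract described above (make_dataloader_train is required, while the validation and test dataloaders are optional) can be sketched with the standard library's abc module. SketchWrapperBase, ListWrapper, and can_instantiate are hypothetical names, and this is not ablator's implementation; that the optional hooks default to returning None is an assumption based only on their `DataLoader | None` return types.

```python
from abc import ABC, abstractmethod
from typing import Any, Iterable, Optional


class SketchWrapperBase(ABC):
    """Hypothetical stand-in for ModelWrapper's override contract."""

    @abstractmethod
    def make_dataloader_train(self, run_config: Any) -> Iterable:
        """Required hook: subclasses must return the training data."""

    def make_dataloader_val(self, run_config: Any) -> Optional[Iterable]:
        # Optional hook (assumed default): no validation data.
        return None

    def make_dataloader_test(self, run_config: Any) -> Optional[Iterable]:
        # Optional hook (assumed default): no test data.
        return None


class ListWrapper(SketchWrapperBase):
    # Overrides only the required hook; lists stand in for DataLoaders.
    def make_dataloader_train(self, run_config: Any) -> Iterable:
        return [[1, 2], [3, 4]]


def can_instantiate(cls: type) -> bool:
    """Return True if cls has no unimplemented abstract methods."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(SketchWrapperBase))  # False
print(can_instantiate(ListWrapper))  # True
print(ListWrapper().make_dataloader_val(None))  # None
```

Making the training hook abstract turns a forgotten override into an error at construction time rather than a failure partway through launching an experiment.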
- config_parser(run_config: RunConfig)
You can override this function to initialize Derived properties, i.e. configuration values that are not decided until the experiment is launched.
Examples
For example, in a GPT-2 model the token embeddings must be resized to match the vocabulary size of the tokenizer in use. That size is known only after the experiment has been launched, because you might want to run an ablation study over different tokenizers:
>>> class MyLMWrapper(ModelWrapper):
...     def config_parser(self, run_config: RunConfig):
...         run_config.model_config.resize_token_embedding = len(self.train_dataloader.dataset.tokenizer)
...         return run_config
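The same idea (a config field that stays unset until launch and is filled in by config_parser) can be sketched without ablator using dataclasses. ModelConfig, TrialConfig, TOKENIZERS, and LMWrapperSketch are hypothetical stand-ins, not ablator classes; in ablator the field would be declared as Derived, which this sketch only approximates with `Optional[int] = None`.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelConfig:
    hidden_dim: int = 64
    # Stand-in for a Derived field: unset until config_parser fills it in.
    vocab_size: Optional[int] = None


@dataclass
class TrialConfig:
    tokenizer_name: str
    model_config: ModelConfig = field(default_factory=ModelConfig)


# Hypothetical vocabularies standing in for real tokenizers.
TOKENIZERS = {
    "char": list("abcdefghijklmnopqrstuvwxyz"),
    "word": ["<unk>", "the", "a", "model", "data"],
}


class LMWrapperSketch:
    def make_tokenizer(self, run_config: TrialConfig) -> list:
        return TOKENIZERS[run_config.tokenizer_name]

    def config_parser(self, run_config: TrialConfig) -> TrialConfig:
        # The derived value depends on which tokenizer this trial uses.
        tokenizer = self.make_tokenizer(run_config)
        run_config.model_config.vocab_size = len(tokenizer)
        return run_config


cfg = LMWrapperSketch().config_parser(TrialConfig(tokenizer_name="word"))
print(cfg.model_config.vocab_size)  # 5
```

An ablation over tokenizer_name then yields a different vocab_size per trial without the user having to compute it by hand.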
- create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) -> None
Creates the model, optimizer, scheduler, and scaler from a saved checkpoint dictionary or from the configuration. You can override this function together with save_dict() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.
- Parameters:
- save_dict: dict[str, Any] | None
The saved checkpoint dictionary to load from. If None, the components are created from the configuration.
- strict_load: bool
Whether to load the model strictly or not.
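The difference between strict and non-strict loading can be sketched without PyTorch. ToyModel and this free-standing create_model are hypothetical; the strict check mirrors the strict argument of torch.nn.Module.load_state_dict (missing or unexpected keys raise a RuntimeError), and both the assumption that ablator forwards strict_load in that way and the "model" checkpoint key are assumptions of this sketch.

```python
import copy
from typing import Optional


class ToyModel:
    """Hypothetical model whose state is a flat dict of named parameters."""

    def __init__(self) -> None:
        # Fresh weights, as if built from the model config.
        self.params = {"linear.weight": [0.0, 0.0], "linear.bias": [0.0]}

    def load_state_dict(self, state: dict, strict: bool = True) -> None:
        missing = sorted(set(self.params) - set(state))
        unexpected = sorted(set(state) - set(self.params))
        if strict and (missing or unexpected):
            # Mirrors torch.nn.Module.load_state_dict(strict=True), which also raises RuntimeError.
            raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
        # Non-strict: copy only the keys both sides share, keep the rest as initialized.
        for name in self.params.keys() & state.keys():
            self.params[name] = copy.deepcopy(state[name])


def create_model(save_dict: Optional[dict] = None, strict_load: bool = True) -> ToyModel:
    """Hypothetical free function: build fresh, then restore from a checkpoint if given."""
    model = ToyModel()
    if save_dict is not None:
        # The "model" key is an assumed checkpoint layout.
        model.load_state_dict(save_dict["model"], strict=strict_load)
    return model


partial_ckpt = {"model": {"linear.weight": [1.0, 2.0]}}  # no "linear.bias" entry
print(create_model(partial_ckpt, strict_load=False).params)
# {'linear.weight': [1.0, 2.0], 'linear.bias': [0.0]}
```

Non-strict loading is typically useful when a checkpoint covers only part of the model, for example when fine-tuning from weights that lack a newly added head; calling create_model(partial_ckpt) with the default strict_load=True raises instead.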
- evaluation_functions() -> dict[str, collections.abc.Callable] | None
You can override this function to return the callables used to evaluate experiment metrics.
- Returns:
- dict[str, Callable] | None
The evaluation functions to use, keyed by metric name. Also see Metrics for details.
Examples
Define the callables:
>>> from sklearn.metrics import accuracy_score, f1_score
>>> def my_accuracy(y_true, y_pred):
...     return accuracy_score(y_true.flatten(), y_pred.flatten())
>>> def my_f1_score(y_true, y_pred):
...     return f1_score(y_true.flatten(), y_pred.flatten(), average='weighted')
Return the callables from evaluation_functions:
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True,
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False,
...         )
...     def evaluation_functions(self):
...         return {
...             "accuracy": my_accuracy,
...             "f1": my_f1_score,
...         }
Note that the callables' parameter names must match the keys of the output dictionary returned by the model's forward method. In this example, forward must return y_true and y_pred so that they match the y_true and y_pred parameters of my_accuracy and my_f1_score:
>>> from torch import nn
>>> class MyModel(nn.Module):
...     def __init__(self, config: CustomModelConfig) -> None:
...         super().__init__()
...         self.model = FashionMNISTModel(config)
...         self.loss = nn.CrossEntropyLoss()
...     def forward(self, x, labels=None):
...         out = self.model(x)
...         loss = None
...         if labels is not None:
...             loss = self.loss(out, labels)
...         out = out.argmax(dim=-1)
...         return {"y_pred": out, "y_true": labels}, loss
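Why the names must line up can be illustrated with a small dispatcher that inspects each callable's signature and passes the matching forward-output entries by keyword. evaluate is a hypothetical helper written only to illustrate the naming requirement; it is not ablator's Metrics implementation, and plain lists stand in for tensors.

```python
import inspect
from typing import Callable, Dict


def my_accuracy(y_true, y_pred):
    # Fraction of positions where the prediction equals the label.
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def evaluate(outputs: Dict[str, list], functions: Dict[str, Callable]) -> Dict[str, float]:
    """Hypothetical dispatcher: pass forward-output entries to each metric by parameter name."""
    results = {}
    for metric_name, fn in functions.items():
        param_names = list(inspect.signature(fn).parameters)
        missing = [p for p in param_names if p not in outputs]
        if missing:
            raise KeyError(f"{metric_name}: forward output has no entry for {missing}")
        results[metric_name] = fn(**{p: outputs[p] for p in param_names})
    return results


outputs = {"y_pred": [0, 1, 1, 2], "y_true": [0, 1, 2, 2]}
print(evaluate(outputs, {"accuracy": my_accuracy}))  # {'accuracy': 0.75}
```

A metric declared as `def bad(labels, preds)` would fail here with a KeyError, because forward returns no "labels" or "preds" entries.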
- make_dataloader_test(run_config: RunConfig) -> DataLoader | None
Creates the test dataloader. You can override this function to return a test dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The test dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True,
...         )
...     def make_dataloader_test(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             test_dataset,
...             batch_size=32,
...             shuffle=False,
...         )
- abstract make_dataloader_train(run_config: RunConfig) -> DataLoader
Creates the training dataloader. You must override this function and return the training dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader
The training dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True,
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False,
...         )
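The batch_size and shuffle arguments in these examples control how torch.utils.data.DataLoader groups and orders samples. The generator below is a simplified standard-library approximation of that behavior; iterate_batches is a hypothetical helper, not part of PyTorch or ablator. It ignores workers, samplers, and collation (it yields plain lists rather than stacked tensors) and keeps a final short batch, as DataLoader does with its default drop_last=False.

```python
import random
from typing import Iterator, List, Optional, Sequence


def iterate_batches(
    dataset: Sequence, batch_size: int, shuffle: bool, seed: Optional[int] = None
) -> Iterator[List]:
    """Yield lists of up to batch_size items; the last batch may be shorter."""
    indices = list(range(len(dataset)))
    if shuffle:
        # With seed=None the order changes on every call, i.e. every epoch.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


print(list(iterate_batches(list(range(5)), batch_size=2, shuffle=False)))
# [[0, 1], [2, 3], [4]]
```

Training loaders usually shuffle so that each epoch sees the samples in a different order, while validation and test loaders keep a fixed order so that evaluation results are reproducible.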
- make_dataloader_val(run_config: RunConfig) -> DataLoader | None
Creates the validation dataloader. You can override this function to return a validation dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The validation dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True,
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False,
...         )
- save_dict() -> dict[str, Any]
Saves the current state of the trainer, including the model parameters and the current states of the optimizer, scaler, and scheduler. You can override this function together with create_model() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.
- Returns:
- dict[str, Any]
The current state of the trainer.
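For a model made of several modules, pairing save_dict() with create_model() amounts to writing one state entry per component and reading the same keys back. The sketch below shows that round trip with the standard library only. Stateful, TwoModuleWrapperSketch, and the key names are hypothetical, pickle stands in for the checkpoint file format, and the sketch does not reproduce the tutorial mentioned above.

```python
import pickle
from typing import Any, Dict, Optional


class Stateful:
    """Hypothetical component exposing state_dict()/load_state_dict()."""

    def __init__(self, **state: Any) -> None:
        self.state = dict(state)

    def state_dict(self) -> Dict[str, Any]:
        return dict(self.state)

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.state = dict(state)


class TwoModuleWrapperSketch:
    def __init__(self) -> None:
        # Freshly initialized components, as if built from config.
        self.encoder = Stateful(w=0.0)
        self.decoder = Stateful(w=0.0)
        self.optimizer = Stateful(lr=0.1, step=0)

    def save_dict(self) -> Dict[str, Any]:
        # One entry per component so each one can be restored independently.
        return {
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def create_model(self, save_dict: Optional[Dict[str, Any]] = None) -> None:
        if save_dict is None:
            return  # no checkpoint: keep the fresh state
        # Read back exactly the keys that save_dict() wrote.
        self.encoder.load_state_dict(save_dict["encoder"])
        self.decoder.load_state_dict(save_dict["decoder"])
        self.optimizer.load_state_dict(save_dict["optimizer"])


trained = TwoModuleWrapperSketch()
trained.encoder.state["w"] = 1.5
trained.optimizer.state["step"] = 10
checkpoint = pickle.dumps(trained.save_dict())  # stands in for a checkpoint file

restored = TwoModuleWrapperSketch()
restored.create_model(pickle.loads(checkpoint))
print(restored.encoder.state, restored.optimizer.state)
# {'w': 1.5} {'lr': 0.1, 'step': 10}
```

Because both methods share one set of keys, overriding only one of them tends to break resumption; that is why the documentation recommends overriding them together.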