Model Training Interface
- class ablator.main.model.wrapper.ModelWrapper(model_class: type[torch.nn.modules.module.Module])
Bases: ModelBase
A wrapper around model_class that removes training boilerplate code. Once make_dataloader_train is overridden to provide a training dataset, you can pass the ModelWrapper object to a trainer (ProtoTrainer or ParallelTrainer) along with a run configuration (RunConfig or ParallelConfig) to launch the experiment. The wrapper lets you customize the training process by overriding various methods, which makes it adaptable to different training paradigms. Several customization use cases are shown below.
- Parameters:
- model_class: type[nn.Module]
The model class to wrap.
- Attributes:
- model_class: type[torch.nn.Module]
The model class to wrap.
- model: torch.nn.Module
The model created from the model class or checkpoint.
- optimizer: Optimizer
The optimizer created from the optimizer config or checkpoint.
- scaler: GradScaler
The scaler created from the scaler config or checkpoint.
- scheduler: Scheduler
The scheduler created from the scheduler config or checkpoint.
- config_parser(run_config: RunConfig) → RunConfig
You can override this function to initialize Derived properties that are not decided until the experiment is launched.
- Parameters:
- run_config: RunConfig
The run config for the experiment.
- Returns:
- RunConfig
The run config, with any derived properties filled in.
Examples
For example, with GPT-2 models you may need to resize the token embeddings to match the vocabulary size of the tokenizer in use. This size is known only after the experiment has been launched, because you might, for instance, run an ablation study over different tokenizers:
>>> class MyLMWrapper(ModelWrapper):
...     def config_parser(self, run_config: RunConfig):
...         run_config.model_config.resize_token_embedding = len(self.train_dataloader.dataset.tokenizer)
...         return run_config
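The same pattern can be sketched without ablator installed. In this minimal sketch, ModelConfig, RunConfig, ToyTokenizer, and ToyLMWrapper are hypothetical stand-ins (not ablator classes), the derived property is modeled as a plain Optional field, and the tokenizer is passed in directly instead of being read from self.train_dataloader.dataset.tokenizer. It shows one config shape picking up a different derived value for each tokenizer in a study.

```python
from dataclasses import dataclass, field
from typing import Optional


# Hypothetical stand-ins for ablator's config objects; only the field used
# in the example above is modeled, as a plain Optional attribute.
@dataclass
class ModelConfig:
    resize_token_embedding: Optional[int] = None  # unknown until launch


@dataclass
class RunConfig:
    model_config: ModelConfig = field(default_factory=ModelConfig)


class ToyTokenizer:
    """Hypothetical tokenizer whose length is its vocabulary size."""

    def __init__(self, vocab):
        self.vocab = list(vocab)

    def __len__(self):
        return len(self.vocab)


class ToyLMWrapper:
    """Hypothetical wrapper mirroring the config_parser override above."""

    def __init__(self, tokenizer):
        # Stands in for self.train_dataloader.dataset.tokenizer.
        self.tokenizer = tokenizer

    def config_parser(self, run_config: RunConfig) -> RunConfig:
        # Fill in the derived field from data that is only now available.
        run_config.model_config.resize_token_embedding = len(self.tokenizer)
        return run_config


# Two tokenizers in an ablation study yield two different derived values.
for vocab in (["a", "b", "c"], ["a", "b", "c", "d", "e"]):
    cfg = ToyLMWrapper(ToyTokenizer(vocab)).config_parser(RunConfig())
    print(cfg.model_config.resize_token_embedding)  # 3, then 5
```

Because the value is computed from the data rather than written into the config file, each trial ends up with an embedding size that matches its own tokenizer.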
- create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True)
Creates the model, optimizer, scheduler, and scaler from a saved checkpoint dictionary or from configuration objects. You can override this function together with save_dict() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.
- Parameters:
- save_dict: dict[str, ty.Any] | None
The saved checkpoint dictionary to load from, by default None.
- strict_load: bool
Whether to raise an error for mismatched keys, by default True.
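The effect of strict_load can be illustrated with plain dictionaries standing in for state dicts. The helper load_state below is hypothetical; it mimics the documented behavior of PyTorch's Module.load_state_dict(strict=...), which reports missing and unexpected keys and raises a RuntimeError on any mismatch when strict is True. That create_model forwards strict_load to such a call is an assumption, not something this page states.

```python
def load_state(model_state: dict, saved_state: dict, strict: bool = True) -> dict:
    """Hypothetical helper: copy matching entries from saved_state into model_state."""
    # Keys the model expects but the checkpoint lacks, and vice versa.
    missing = sorted(model_state.keys() - saved_state.keys())
    unexpected = sorted(saved_state.keys() - model_state.keys())
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    # Non-strict: load only the keys both sides share.
    for key in model_state.keys() & saved_state.keys():
        model_state[key] = saved_state[key]
    return {"missing_keys": missing, "unexpected_keys": unexpected}


model = {"encoder.weight": 0.0, "head.weight": 0.0}
checkpoint = {"encoder.weight": 1.5, "old_head.weight": 2.0}

report = load_state(model, checkpoint, strict=False)
print(report)  # {'missing_keys': ['head.weight'], 'unexpected_keys': ['old_head.weight']}
print(model)   # {'encoder.weight': 1.5, 'head.weight': 0.0}
```

With strict=True the same call raises instead, which is the safer default when a renamed layer would otherwise be silently left at its initial values.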
- evaluation_functions() → dict[str, collections.abc.Callable] | None
You can override this function to return the evaluation functions (callables keyed by metric name) that will be used to compute experiment metrics.
- Returns:
- dict[str, Callable] | None
The evaluation functions to use.
Examples
Define the callables:
>>> from sklearn.metrics import accuracy_score, f1_score
>>> def my_accuracy(y_true, y_pred):
...     return accuracy_score(y_true.flatten(), y_pred.flatten())
>>> def my_f1_score(y_true, y_pred):
...     return f1_score(y_true.flatten(), y_pred.flatten(), average='weighted')
Return the callables from evaluation_functions:
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
...
...     def evaluation_functions(self):
...         return {
...             "accuracy": my_accuracy,
...             "f1": my_f1_score
...         }
Note that the callables’ parameter names must match the keys of the output dictionary returned by the model’s forward method. In this example, the forward method must return y_true and y_pred to match the y_true and y_pred parameters of my_accuracy and my_f1_score:
>>> from torch import nn
>>> class MyModel(nn.Module):
...     def __init__(self, config: CustomModelConfig) -> None:
...         super().__init__()
...         self.model = FashionMNISTModel(config)
...         self.loss = nn.CrossEntropyLoss()
...
...     def forward(self, x, labels=None):
...         out = self.model(x)
...         loss = None
...         if labels is not None:
...             loss = self.loss(out, labels)
...         out = out.argmax(dim=-1)
...         return {"y_pred": out, "y_true": labels}, loss
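One way a trainer could enforce this name matching is to inspect each metric's signature and pass it the identically named entries of the forward output. The evaluate helper below is a hypothetical sketch of that idea, not ablator's implementation, and it works on Python lists instead of tensors or arrays.

```python
import inspect
from typing import Any, Callable


def evaluate(metrics: dict[str, Callable[..., Any]], outputs: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: call each metric with the outputs named by its parameters."""
    results = {}
    for name, fn in metrics.items():
        params = inspect.signature(fn).parameters
        # Every parameter name must be a key of the forward output dictionary.
        missing = [p for p in params if p not in outputs]
        if missing:
            raise KeyError(f"metric {name!r} needs {missing}, not found in the model outputs")
        results[name] = fn(**{p: outputs[p] for p in params})
    return results


def my_accuracy(y_true, y_pred):
    # List-based variant: fraction of positions where prediction equals label.
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


outputs = {"y_pred": [1, 0, 2, 2], "y_true": [1, 0, 1, 2]}
print(evaluate({"accuracy": my_accuracy}, outputs))  # {'accuracy': 0.75}
```

A metric declared as, say, def f(labels, logits) would fail here with a KeyError, which is the mismatch the note above warns about.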
- make_dataloader_test(run_config: RunConfig) → DataLoader | None
Creates the test dataloader. You can override this function to return a test dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The test dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...
...     def make_dataloader_test(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             test_dataset,
...             batch_size=32,
...             shuffle=False
...         )
- abstract make_dataloader_train(run_config: RunConfig) → DataLoader
Creates the training dataloader. You must override this function to return the training dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader
The training dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
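Because make_dataloader_train is abstract, a subclass that does not override it cannot be instantiated. The sketch below reproduces that behavior with Python's abc module on a hypothetical ToyWrapperBase. Having the optional val and test hooks return None by default is an assumption based on their DataLoader | None return types, and a plain list of batches stands in for a torch DataLoader.

```python
from abc import ABC, abstractmethod


class ToyWrapperBase(ABC):
    """Hypothetical stand-in for ModelWrapper's dataloader hooks."""

    @abstractmethod
    def make_dataloader_train(self, run_config):
        """Required: subclasses must return the training dataloader."""

    def make_dataloader_val(self, run_config):
        return None  # assumed default: no validation data

    def make_dataloader_test(self, run_config):
        return None  # assumed default: no test data


class NoTrainLoader(ToyWrapperBase):
    pass  # forgets to override make_dataloader_train


class MyWrapper(ToyWrapperBase):
    def make_dataloader_train(self, run_config):
        # Any iterable of batches stands in for a DataLoader here.
        return [([0.0, 1.0], 1), ([1.0, 0.0], 0)]


try:
    NoTrainLoader()
except TypeError as err:
    print("cannot instantiate:", err)

wrapper = MyWrapper()
print(len(wrapper.make_dataloader_train(run_config=None)), wrapper.make_dataloader_val(run_config=None))  # 2 None
```

The failure happens at construction time, so a missing training loader is caught before any trainer is launched.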
- make_dataloader_val(run_config: RunConfig) → DataLoader | None
Creates the validation dataloader. You can override this function to return a validation dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The validation dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
- save_dict() → dict[str, Any]
Saves the current state of the trainer into a dictionary, including the model parameters and the states of the optimizer, scaler, and scheduler. You can override this function together with create_model() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.
- Returns:
- dict[str, ty.Any]
The current state of the trainer, with the following keys:
- model: the model’s state
- optimizer: the optimizer’s state
- scheduler: the scheduler’s state
- scaler: the GradScaler’s state
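save_dict and create_model have to stay symmetric: whatever keys one writes, the other must read back. The sketch below shows that round trip with a hypothetical ToyWrapper and a hypothetical Stateful class that follows PyTorch's state_dict / load_state_dict naming convention. It is not ablator's implementation; real components would be a torch module, optimizer, scheduler, and GradScaler.

```python
import copy
from typing import Any, Optional

COMPONENTS = ("model", "optimizer", "scheduler", "scaler")  # keys listed under Returns


class Stateful:
    """Hypothetical component using PyTorch's state_dict/load_state_dict naming."""

    def __init__(self, **state: Any) -> None:
        self._state = dict(state)

    def state_dict(self) -> dict[str, Any]:
        return copy.deepcopy(self._state)

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self._state = copy.deepcopy(state)


class ToyWrapper:
    """Hypothetical wrapper whose save_dict and create_model mirror each other."""

    def create_model(self, save_dict: Optional[dict[str, Any]] = None) -> None:
        # Build fresh components, as if from the configuration objects.
        self.model = Stateful(weight=0.0)
        self.optimizer = Stateful(step=0)
        self.scheduler = Stateful(last_epoch=0)
        self.scaler = Stateful(scale=1.0)
        if save_dict is not None:
            # Restore each component from the key save_dict() stored it under.
            for key in COMPONENTS:
                getattr(self, key).load_state_dict(save_dict[key])

    def save_dict(self) -> dict[str, Any]:
        return {key: getattr(self, key).state_dict() for key in COMPONENTS}


wrapper = ToyWrapper()
wrapper.create_model()
wrapper.model.load_state_dict({"weight": 0.5})  # pretend training updated the model
checkpoint = wrapper.save_dict()

restored = ToyWrapper()
restored.create_model(save_dict=checkpoint)
print(restored.model.state_dict())  # {'weight': 0.5}
```

If you override one method to store an extra entry, for example a second module's state, override the other to restore it under the same key.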