Saving and loading multi-module models
Ablator is a flexible framework: you can override its functions to extend it to your use case. In this tutorial, we show how ablator can be customized to save and load multi-module models. This is helpful when a model consists of multiple modules and you want to save the entire model to a file and load it back later. Sample use cases include the encoder and decoder blocks of a transformer model, ensemble models, etc.
For demonstration purposes, we will create an ensemble of 3 simple neural networks with one hidden layer each, train them on the breast cancer dataset for 30 epochs, save the ensemble as a 3-module model, load it back, and train it for another 30 epochs.
Let us first import the necessary modules:
from ablator import ModelConfig, OptimizerConfig, TrainConfig, ParallelConfig
from ablator import ModelWrapper, ParallelTrainer
from ablator.main.configs import SearchSpace
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import f1_score, accuracy_score  # for custom evaluation functions
import shutil
import os
import typing as ty  # for the type hints in create_model()
Preparing the data
class BreastCancerDataset(Dataset):
    def __init__(self, data, targets, scaler=None):
        # Fit a new scaler only when none is given; the test split should
        # reuse the scaler fitted on the training split
        if scaler is None:
            scaler = MinMaxScaler().fit(data)
        self.scaler = scaler
        self.data = self.scaler.transform(data)
        self.targets = targets

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.data)

# Load dataset from scikit-learn
breast_cancer = load_breast_cancer()
data = breast_cancer.data
targets = breast_cancer.target
# Split the data into train and test sets
train_data, test_data, train_targets, test_targets = train_test_split(data, targets, test_size=0.2, random_state=42)
# Create train and test datasets; the test set is scaled with the training scaler
train_dataset = BreastCancerDataset(train_data, train_targets)
test_dataset = BreastCancerDataset(test_data, test_targets, scaler=train_dataset.scaler)
Build the ensemble model
Assemble the modules
We create the ensemble model named MyEnsemble. It consists of 3 separate neural networks; the final prediction probability is computed by summing the outputs (logits) of the 3 networks and applying softmax to the sum.
class MyEnsemble(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        self.nnet1 = NNet()
        self.nnet2 = NNet()
        self.nnet3 = NNet()

    def forward(self, x, labels=None):
        # Aggregate the raw outputs (logits) of the three networks
        logits = self.nnet1(x) + self.nnet2(x) + self.nnet3(x)
        probs = F.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1).reshape(-1, 1)
        loss = None
        if labels is not None:
            # F.cross_entropy applies log-softmax itself, so it takes logits, not probs
            loss = F.cross_entropy(logits, labels)
            labels = labels.reshape(-1, 1)
        return {"preds": preds, "labels": labels}, loss
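MyEnsemble instantiates NNet three times, but NNet itself is not defined in this section. Below is a minimal sketch of a one-hidden-layer network that would fit, assuming 30 input features (the breast cancer dataset has 30) and 2 output classes. The hidden size of 32 and the layer names fc1 and fc2 are our own assumptions, not taken from the original tutorial. The network returns raw logits, which is what the ensemble sums.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NNet(nn.Module):
    """One-hidden-layer network (assumed sizes: 30 -> 32 -> 2)."""

    def __init__(self, input_size: int = 30, hidden_size: int = 32, num_classes: int = 2) -> None:
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)   # hidden layer
        self.fc2 = nn.Linear(hidden_size, num_classes)  # output layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Return raw logits; any softmax is left to the caller
        return self.fc2(F.relu(self.fc1(x)))


# Quick check on a random batch of 4 samples with 30 features each
net = NNet()
logits = net(torch.randn(4, 30))
print(logits.shape)                      # torch.Size([4, 2])
print(sorted(net.state_dict().keys()))   # ['fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight']
```

Because all three ensemble members are instances of the same class, they share these parameter names; MyEnsemble.state_dict() distinguishes them by prefixing nnet1., nnet2., and nnet3.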
Configure the experiment
Next, we set up the ablation experiment by defining its configurations.
Model configuration
Since we are not ablating the model architecture, no custom model configuration is needed, so we use an empty ModelConfig.
model_config = ModelConfig()
Optimizer configuration
We will use the Adam optimizer, with the learning rate initialized to 0.001:
optimizer_config = OptimizerConfig(
    name="adam",
    arguments={"lr": 0.001},
)
We will also define a search space so that the trials explore learning rates between 0.001 and 0.01:
search_space = {
    "train_config.optimizer_config.arguments.lr": SearchSpace(
        value_range=[0.001, 0.01],
        value_type="float",
    )
}
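The search-space key is a dotted path into the nested configuration: train_config, then optimizer_config, then arguments, then lr. The sketch below illustrates which field a sampled value ends up in, using a plain nested dict and a hypothetical helper set_by_path; it is not ablator's implementation, and 0.005 is just an example value inside the range.

```python
def set_by_path(config: dict, dotted_key: str, value) -> None:
    """Hypothetical helper: assign `value` at the nested location named by `dotted_key`."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for name in parents:
        node = node[name]  # walk down one level per path component
    node[leaf] = value


config = {"train_config": {"optimizer_config": {"name": "adam", "arguments": {"lr": 0.001}}}}
set_by_path(config, "train_config.optimizer_config.arguments.lr", 0.005)
print(config["train_config"]["optimizer_config"]["arguments"])  # {'lr': 0.005}
```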
Training configuration
train_config = TrainConfig(
    dataset="breast-cancer",
    batch_size=32,
    epochs=30,
    optimizer_config=optimizer_config,
    scheduler_config=None,
    rand_weights_init=True,
)
Running configuration (parallel config)
Combine the model configuration, training configuration, and search space into the running configuration:
run_config = ParallelConfig(
    train_config=train_config,
    model_config=model_config,
    metrics_n_batches=800,
    experiment_dir="/tmp/experiments/",
    device="cuda",
    amp=True,
    random_seed=42,
    total_trials=5,
    concurrent_trials=3,
    search_space=search_space,
    optim_metrics={"val_loss": "min"},
    gpu_mb_per_experiment=1024,
)
Model wrapper
Besides overriding the data loader functions and the evaluation functions, we override the model saving and loading functions of ModelWrapper so that ablator saves the model as a 3-module model:
{
    "model": {
        "nnet1": self.model.nnet1.state_dict(),
        "nnet2": self.model.nnet2.state_dict(),
        "nnet3": self.model.nnet3.state_dict()
    }
}
Multi-module model saving
We override ModelWrapper.save_dict() to save the entire model as a dictionary of modules.
def save_dict(self):
    saved_dict = super().save_dict()
    # Replace the single model state dict with one state dict per module
    model_state_dict = {
        "nnet1": self.model.nnet1.state_dict(),
        "nnet2": self.model.nnet2.state_dict(),
        "nnet3": self.model.nnet3.state_dict(),
    }
    saved_dict["model"] = model_state_dict
    return saved_dict
By default, ablator saves the model as a whole, i.e., saved_dict["model"] = self.model.state_dict().
In our example, the modules nnet1, nnet2, and nnet3 of MyEnsemble are accessible as self.model.nnet1, self.model.nnet2, and self.model.nnet3, and we save each module's state dictionary under its own key in saved_dict["model"].
The saved model is therefore a dictionary of modules. Because all three modules are NNet instances, they share the same parameter names (the fc1 names below are illustrative, assuming NNet has a layer named fc1):
saved_dict = {
    "model": {
        "nnet1": {"fc1.weight": weight, "fc1.bias": bias, ...},
        "nnet2": {"fc1.weight": weight, "fc1.bias": bias, ...},
        "nnet3": {"fc1.weight": weight, "fc1.bias": bias, ...},
    },
    ...
}
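Saving each submodule's state_dict() separately gives the same result as taking MyEnsemble.state_dict() and grouping its keys by their first dotted component, since MyEnsemble holds no parameters of its own. The pure-Python sketch below shows that regrouping on a plain dictionary; group_by_module is a hypothetical helper, the key names are illustrative, and strings stand in for tensors.

```python
def group_by_module(flat_state: dict) -> dict:
    """Hypothetical helper: {"nnet1.fc1.weight": w, ...} -> {"nnet1": {"fc1.weight": w}, ...}."""
    grouped: dict = {}
    for key, value in flat_state.items():
        module, param = key.split(".", 1)  # split only at the first dot
        grouped.setdefault(module, {})[param] = value
    return grouped


# Strings stand in for tensors; key names are illustrative
flat = {
    "nnet1.fc1.weight": "w1", "nnet1.fc1.bias": "b1",
    "nnet2.fc1.weight": "w2", "nnet2.fc1.bias": "b2",
    "nnet3.fc1.weight": "w3", "nnet3.fc1.bias": "b3",
}
grouped = group_by_module(flat)
print(grouped["nnet2"])  # {'fc1.weight': 'w2', 'fc1.bias': 'b2'}
```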
After running the experiment, you can verify this with torch.load(<path_to_checkpoint>), where <path_to_checkpoint> is the path to one of the checkpoints saved in the experiment directory.
Multi-module model loading
Now that the model is saved as a multi-module dictionary, we also need to change how ablator loads it. We do this by overriding ModelWrapper.create_model().
def create_model(
    self,
    save_dict: dict[str, ty.Any] | None = None,
    strict_load: bool = True,
) -> None:
    if save_dict is not None:
        # Flatten {"nnet1": {"fc1.weight": ...}} into {"nnet1.fc1.weight": ...}
        flat_state_dict = {}
        for module_name, module_state in save_dict["model"].items():
            for key, value in module_state.items():
                flat_state_dict[f"{module_name}.{key}"] = value
        save_dict["model"] = flat_state_dict
    # Pass strict_load through instead of hard-coding it
    super().create_model(save_dict=save_dict, strict_load=strict_load)
By default, ablator loads the model as a whole, i.e., model.load_state_dict(save_dict["model"]).
The loop above therefore prefixes every parameter key with its module name, producing keys such as nnet1.fc1.weight, nnet1.fc1.bias, nnet2.fc1.weight, nnet2.fc1.bias, and so on. These are exactly the keys of MyEnsemble.state_dict(), so the call to super().create_model() loads the model correctly.
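The key rewriting in create_model is the inverse of grouping by module: each parameter name is prefixed with its module name. The sketch below isolates that step as a hypothetical flatten_modules function on plain dicts, with strings standing in for tensors, and shows the resulting key order.

```python
def flatten_modules(nested: dict) -> dict:
    """Hypothetical helper: {"nnet1": {"fc1.weight": w}} -> {"nnet1.fc1.weight": w}."""
    return {
        f"{module}.{param}": value
        for module, params in nested.items()
        for param, value in params.items()
    }


nested = {
    "nnet1": {"fc1.weight": "w1", "fc1.bias": "b1"},
    "nnet2": {"fc1.weight": "w2", "fc1.bias": "b2"},
    "nnet3": {"fc1.weight": "w3", "fc1.bias": "b3"},
}
flat = flatten_modules(nested)
print(list(flat))
# ['nnet1.fc1.weight', 'nnet1.fc1.bias', 'nnet2.fc1.weight', 'nnet2.fc1.bias',
#  'nnet3.fc1.weight', 'nnet3.fc1.bias']
```

Writing it as a comprehension also builds a new dictionary rather than mutating the nested one, which can be handy if the same loaded checkpoint is reused.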
Below is the complete model wrapper. It provides the datasets via the make_dataloader_train and make_dataloader_val functions, registers the custom evaluation metrics, and adds the multi-module saving and loading code discussed above:
class MyEnsembleWrapper(ModelWrapper):
    def make_dataloader_train(self, run_config: ParallelConfig):
        return DataLoader(train_dataset, batch_size=run_config.train_config.batch_size, shuffle=True)

    def make_dataloader_val(self, run_config: ParallelConfig):
        return DataLoader(test_dataset, batch_size=run_config.train_config.batch_size, shuffle=False)

    def evaluation_functions(self):
        # Metrics defined in the "Custom evaluation" section below
        return {"accuracy": my_accuracy, "f1": my_f1_score}

    def save_dict(self):
        saved_dict = super().save_dict()
        # One state dict per module instead of one for the whole model
        model_state_dict = {
            "nnet1": self.model.nnet1.state_dict(),
            "nnet2": self.model.nnet2.state_dict(),
            "nnet3": self.model.nnet3.state_dict(),
        }
        saved_dict["model"] = model_state_dict
        return saved_dict

    def create_model(self, save_dict=None, strict_load=True):
        if save_dict is not None:
            # Flatten the per-module dicts back into MyEnsemble's state-dict keys
            flat_state_dict = {}
            for module_name, module_state in save_dict["model"].items():
                for key, value in module_state.items():
                    flat_state_dict[f"{module_name}.{key}"] = value
            save_dict["model"] = flat_state_dict
        super().create_model(save_dict=save_dict, strict_load=strict_load)
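To check the saving and loading logic without launching a full ablator experiment, the sketch below reproduces it with plain PyTorch: it builds the per-module dictionary the way save_dict does, flattens it the way create_model does, and loads the result with load_state_dict(strict=True) into a freshly initialized model. TinyEnsemble and its single linear layers are hypothetical stand-ins for MyEnsemble and NNet.

```python
import torch
import torch.nn as nn


class TinyEnsemble(nn.Module):
    """Hypothetical stand-in for MyEnsemble: three small linear members."""

    def __init__(self) -> None:
        super().__init__()
        self.nnet1 = nn.Linear(30, 2)
        self.nnet2 = nn.Linear(30, 2)
        self.nnet3 = nn.Linear(30, 2)


source = TinyEnsemble()

# Saving side: one state dict per module, as in save_dict()
nested = {name: getattr(source, name).state_dict() for name in ("nnet1", "nnet2", "nnet3")}

# Loading side: prefix each key with its module name, as in create_model()
flat = {f"{name}.{key}": value for name, params in nested.items() for key, value in params.items()}

target = TinyEnsemble()  # different random initialization
result = target.load_state_dict(flat, strict=True)  # raises on missing or unexpected keys

# Every tensor in the target should now equal the corresponding source tensor
same = all(
    torch.equal(p, q)
    for p, q in zip(source.state_dict().values(), target.state_dict().values())
)
print(same)  # True
```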
Custom evaluation (Optional)
We will use accuracy and weighted F1 score as evaluation metrics. scikit-learn metric functions take the true labels as their first argument, so labels are passed before preds:
def my_accuracy(preds, labels):
    return accuracy_score(labels.flatten(), preds.flatten())

def my_f1_score(preds, labels):
    return f1_score(labels.flatten(), preds.flatten(), average="weighted")
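For accuracy the argument order makes no difference, but the weighted F1 score weights each class by its support in the first argument (y_true), so swapping the arguments changes the result. The small example below uses made-up predictions and labels in the (N, 1) shape that MyEnsemble returns.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Made-up predictions and labels, shaped (N, 1) like MyEnsemble's outputs
preds = np.array([[1], [0], [1], [1]])
labels = np.array([[1], [0], [0], [1]])

acc = accuracy_score(labels.flatten(), preds.flatten())
f1_correct = f1_score(labels.flatten(), preds.flatten(), average="weighted")
f1_swapped = f1_score(preds.flatten(), labels.flatten(), average="weighted")

print(acc)                   # 0.75
print(round(f1_correct, 4))  # 0.7333
print(round(f1_swapped, 4))  # 0.7667
```

Here both classes have a per-class F1 of 2/3 and 0.8 either way; only the support weights (2 and 2 versus 1 and 3) differ, which is why the two weighted scores disagree.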
Launch the ablation experiment
Everything is ready, so we can now launch the ablation experiment.
shutil.rmtree(run_config.experiment_dir, ignore_errors=True)  # start from an empty experiment directory
wrapper = MyEnsembleWrapper(
    model_class=MyEnsemble,
)
ablator = ParallelTrainer(
    wrapper=wrapper,
    run_config=run_config,
)
ablator.launch(working_directory=os.getcwd(), ray_head_address="auto")
When the experiment finishes, you should see model checkpoints in each trial's output folder, and you can verify the structure of the multi-module state by loading one of them with torch.load(<path_to_checkpoint>).
You can also rerun the ablation experiment from one of these checkpoints by setting the init_chkpt parameter of the running configuration (replace the path below with the path to your own checkpoint). Copy the checkpoint to a location outside the experiment directory used for the rerun first, because the rerun code deletes that directory with shutil.rmtree before launching.
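One way to keep a checkpoint safe is to locate it in the first run's experiment directory and copy it elsewhere. The sketch below assumes checkpoints are .pt files inside best_checkpoints folders, as in the init_chkpt path used for the rerun; find_checkpoints and copy_checkpoint are hypothetical helpers, not part of ablator, and they do not decide which trial is best, so pick the file you want from the returned list. The demo runs on a throwaway directory tree rather than a real experiment.

```python
from __future__ import annotations

import shutil
import tempfile
from pathlib import Path


def find_checkpoints(experiment_dir: str | Path) -> list[Path]:
    """List checkpoint files, assuming the layout <...>/best_checkpoints/*.pt."""
    return sorted(Path(experiment_dir).rglob("best_checkpoints/*.pt"))


def copy_checkpoint(checkpoint: Path, dest_dir: str | Path) -> Path:
    """Copy one checkpoint into dest_dir, creating the directory if needed."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    return Path(shutil.copy2(checkpoint, dest / checkpoint.name))


# Demo on a throwaway directory tree that mimics the assumed layout
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp, "experiments", "trial_0", "best_checkpoints")
    ckpt_dir.mkdir(parents=True)
    (ckpt_dir / "MyEnsemble_0000000210.pt").write_bytes(b"placeholder")
    candidates = find_checkpoints(Path(tmp, "experiments"))
    saved = copy_checkpoint(candidates[0], Path(tmp, "saved_checkpoints"))
    saved_name, saved_exists = saved.name, saved.exists()

print(saved_name, saved_exists)  # MyEnsemble_0000000210.pt True
```

With a real run you would call find_checkpoints on the first experiment's directory, copy the chosen file, and point init_chkpt at the copy.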
run_config = ParallelConfig(
    train_config=train_config,
    model_config=model_config,
    metrics_n_batches=800,
    experiment_dir="/tmp/experiments/",
    device="cuda",
    amp=True,
    random_seed=42,
    total_trials=5,
    concurrent_trials=3,
    search_space=search_space,
    optim_metrics={"val_loss": "min"},
    gpu_mb_per_experiment=1024,
    init_chkpt="/tmp/experiments1/experiment_7ae3_9991/2ca5_9991/best_checkpoints/MyEnsemble_0000000210.pt",
)
shutil.rmtree(run_config.experiment_dir, ignore_errors=True)
wrapper = MyEnsembleWrapper(
    model_class=MyEnsemble,
)
ablator = ParallelTrainer(
    wrapper=wrapper,
    run_config=run_config,
)
metrics = ablator.launch(working_directory=os.getcwd(), ray_head_address=None)
The experiment loads the model from the checkpoint and continues training. That's it: this example shows how ablator can be customized to fit your needs.
Conclusion
In this tutorial, we have shown the flexibility of the ablator framework. With some understanding of the model wrapper, you can adapt the framework to your own use cases.