Saving and loading multi-module models#

Ablator is a flexible framework: you can override its functions to fit your use case. In this tutorial, we show how to customize ablator so that it can save and load multi-module models. This is helpful when a model consists of multiple modules and you want to save the entire model to a file and load it back later. Sample use cases include the encoder and decoder blocks of a transformer model, ensemble models, etc.

For demonstration purposes, we will create an ensemble of 3 simple neural networks with one hidden layer each, train it on the breast cancer dataset for 30 epochs, save the ensemble as a 3-module model, load it back, and train for another 30 epochs.

Let us first import the necessary modules:

[ ]:
try:
    import ablator
except ImportError:
    !pip install ablator
    print("Stopping RUNTIME! Please run again") # This script automatically restart runtime (if ablator is not found and installing is needed) so changes are applied
    import os

    os.kill(os.getpid(), 9)
[2]:
# os and shutil are used when launching the experiment
import os
import shutil

from ablator import ModelConfig, ModelWrapper, OptimizerConfig, TrainConfig, ParallelConfig, ParallelTrainer
from ablator.config.hpo import SearchSpace

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F

from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

Preparing the data#

[3]:
class BreastCancerDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.scaler = MinMaxScaler()
        self.data = self.scaler.fit_transform(self.data)
        self.targets = targets

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.data)

# Load dataset from scikit-learn
breast_cancer = load_breast_cancer()
data = breast_cancer.data
targets = breast_cancer.target

# Split the data into train and test sets
train_data, test_data, train_targets, test_targets = train_test_split(data, targets, test_size=0.2, random_state=42)

# Create train and test datasets
train_dataset = BreastCancerDataset(train_data, train_targets)
test_dataset = BreastCancerDataset(test_data, test_targets)
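
Note that each BreastCancerDataset above fits its own MinMaxScaler, so the test split is scaled with its own minima and maxima. A common alternative, sketched below with illustrative variable names that are not part of the tutorial, fits the scaler on the training split only and reuses it for the test split. Whether this matters for your experiment is a judgment call; the sketch only shows the mechanics.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# Same split as in the tutorial
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit the scaler on the training split only, then reuse it for the test split
scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

print(X_train_scaled.shape, X_test_scaled.shape)
# Training features span [0, 1]; test features may fall slightly outside that range
print(X_train_scaled.min(), X_train_scaled.max())
print(X_test_scaled.min(), X_test_scaled.max())
```

The scaled arrays could then be passed to a dataset class that stores them as-is instead of fitting a new scaler.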

Build the ensemble model#

Simple 1-hidden layer neural network module#

We create a simple NN module with one hidden layer of size 50 and a ReLU activation applied to the hidden layer. The input size is 30, matching the 30 features of the dataset, and the output layer size is 2, corresponding to the two classes of the dataset.

[4]:
class NNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(30, 50)
        self.fc2 = nn.Linear(50, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)

        return x
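
As a quick sanity check, the sketch below repeats the same architecture and passes a random batch through it. The batch size of 4 is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

class NNet(nn.Module):
    """Same architecture as the tutorial's NNet: 30 -> 50 -> 2."""
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(30, 50)
        self.fc2 = nn.Linear(50, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

net = NNet()
batch = torch.rand(4, 30)  # hypothetical batch of 4 samples with 30 features
logits = net(batch)
n_params = sum(p.numel() for p in net.parameters())

print(logits.shape)  # torch.Size([4, 2])
print(n_params)      # 30*50 + 50 + 50*2 + 2 = 1652
```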

Assemble the modules#

We create the ensemble model, named MyEnsemble. It consists of 3 separate neural networks; the final prediction probability is computed by summing the outputs of the 3 networks and applying softmax to the sum.

[5]:
class MyEnsemble(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        self.nnet1 = NNet()
        self.nnet2 = NNet()
        self.nnet3 = NNet()

    def forward(self, x, labels=None):
        x1 = self.nnet1(x)
        x2 = self.nnet2(x)
        x3 = self.nnet3(x)

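        # Sum the raw outputs (logits) of the three networks
        logits = x1 + x2 + x3
        ensemble = F.softmax(logits, dim=1)
        preds = torch.argmax(ensemble, dim=1).reshape(-1, 1)

        # F.cross_entropy applies log-softmax internally, so it takes the logits
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            labels = labels.reshape(-1, 1)

        return {"preds": preds, "labels": labels}, loss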
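
When writing a forward pass like this one, keep in mind that F.cross_entropy applies log-softmax internally and therefore expects unnormalized scores (logits). The sketch below, using made-up values, compares the loss computed from logits with the loss computed from already-softmaxed probabilities. In the second case, the loss for a confident, correct prediction cannot get close to zero.

```python
import torch
import torch.nn.functional as F

# Made-up, very confident logits for one sample whose true class is 0
logits = torch.tensor([[10.0, -10.0]])
labels = torch.tensor([0])

probs = F.softmax(logits, dim=1)

loss_from_logits = F.cross_entropy(logits, labels)
loss_from_probs = F.cross_entropy(probs, labels)  # softmax is effectively applied twice here

# cross_entropy is log_softmax followed by the negative log-likelihood loss
reference = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(loss_from_logits.item())  # close to 0
print(loss_from_probs.item())   # about 0.313, i.e. log(1 + exp(-1))
```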

Configure the experiment#

Now it’s time we set up the ablation experiment by defining configurations.

Model configuration#

Since we are not ablating the model architecture, no custom model configuration is needed; an empty ModelConfig is enough.

[6]:
model_config = ModelConfig()

Optimizer configuration#

We will use the Adam optimizer, with the learning rate initialized to 0.001:

[8]:
optimizer_config = OptimizerConfig(
    name="adam",
    arguments={"lr": 0.001}
)
optimizer_config
[8]:
OptimizerConfig(name='adam', arguments={'betas': (0.9, 0.999), 'weight_decay': 0.0, 'lr': 0.001})

We will also define a search space for different learning rate values:

[9]:
search_space = {
    "train_config.optimizer_config.arguments.lr": SearchSpace(
        value_range = [0.001, 0.01],
        value_type = 'float'
    )
}

Training configuration#

[10]:
train_config = TrainConfig(
    dataset="breast-cancer",
    batch_size=32,
    epochs=30,
    optimizer_config=optimizer_config,
    scheduler_config=None
)
train_config
[10]:
TrainConfig(dataset='breast-cancer', batch_size=32, epochs=30, optimizer_config={'name': 'adam', 'arguments': {'betas': (0.9, 0.999), 'weight_decay': 0.0, 'lr': 0.001}}, scheduler_config=None)

Running configuration (parallel config)#

Combine model configuration, train configuration, and search space into the running configuration:

[11]:
run_config = ParallelConfig(
    train_config=train_config,
    model_config=model_config,
    metrics_n_batches = 300,
    experiment_dir = "/tmp/experiments/",
    device="cuda",
    amp=True,
    random_seed = 42,
    total_trials = 5,
    concurrent_trials = 2,
    search_space = search_space,
    optim_metrics = {"val_loss": "min"},
    optim_metric_name = "val_loss",
    gpu_mb_per_experiment = 512
)

Model wrapper#

Besides overriding the data loader functions and evaluation functions, we modify the model saving and loading functions of ModelWrapper so that ablator saves the model as a 3-module model:

{
    "model": {
        "nnet1": self.model.nnet1.state_dict(),
        "nnet2": self.model.nnet2.state_dict(),
        "nnet3": self.model.nnet3.state_dict()
    }
}

Multi-module model saving#

We will override the ModelWrapper.save_dict() function to save the entire model as a dictionary of modules. The function looks like this:

def save_dict(self):
    saved_dict = super().save_dict()
    model_state_dict = {
        "nnet1": self.model.nnet1.state_dict(),
        "nnet2": self.model.nnet2.state_dict(),
        "nnet3": self.model.nnet3.state_dict(),
    }
    saved_dict["model"] = model_state_dict

    return saved_dict

By default, the ablator framework saves the model as a whole, i.e., saved_dict["model"] = self.model.state_dict().

In our example, as you can see, modules nnet1, nnet2, and nnet3 from MyEnsemble can be accessed via self.model.nnet1, self.model.nnet2, and self.model.nnet3 respectively, and we will save these modules’ state dictionaries into saved_dict["model"].

This way, the model saved will be a dictionary of modules:

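saved_dict = {
    "model": {
        "nnet1": {"fc1.weight": weight, "fc1.bias": bias, "fc2.weight": weight, "fc2.bias": bias},
        "nnet2": {"fc1.weight": weight, "fc1.bias": bias, "fc2.weight": weight, "fc2.bias": bias},
        "nnet3": {"fc1.weight": weight, "fc1.bias": bias, "fc2.weight": weight, "fc2.bias": bias},
    },
    ...
}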

After running the experiment, you can use torch.load(<path_to_checkpoint>) to verify this, where <path_to_checkpoint> is the path to one of the models that are saved in the experiment directory.
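
As a rough illustration of what such a check shows, the sketch below builds the nested dictionary by hand, writes it with torch.save, and reads it back with torch.load. It does not use ablator, and the file name is hypothetical. A real ablator checkpoint contains more entries than "model", and depending on your PyTorch version, loading it may require passing weights_only=False.

```python
import os
import tempfile

import torch
import torch.nn as nn

class NNet(nn.Module):
    """Only the layers are needed here, since we just inspect the state dict."""
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(30, 50)
        self.fc2 = nn.Linear(50, 2)

# Build the nested "model" entry: one state dict per module
nets = {name: NNet() for name in ("nnet1", "nnet2", "nnet3")}
checkpoint = {"model": {name: net.state_dict() for name, net in nets.items()}}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_checkpoint.pt")  # hypothetical file name
    torch.save(checkpoint, path)
    loaded = torch.load(path, map_location="cpu")

module_names = sorted(loaded["model"].keys())
param_names = sorted(loaded["model"]["nnet1"].keys())
print(module_names)  # ['nnet1', 'nnet2', 'nnet3']
print(param_names)   # ['fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight']
```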

Multi-module model loading#

Now that we have saved a multi-module model, we also need to change how ablator loads the model. We do this by overriding the ModelWrapper.create_model() function.

import typing as ty

def create_model(
    self,
    save_dict: dict[str, ty.Any] | None = None,
    strict_load: bool = True,
) -> None:
    if save_dict is not None:
        # Flatten {"nnet1": {"fc1.weight": ...}} into {"nnet1.fc1.weight": ...}
        nd_save_dict = {}
        for nnet in save_dict["model"]:
            for key in save_dict["model"][nnet]:
                new_key = nnet + "." + key
                nd_save_dict[new_key] = save_dict["model"][nnet][key]
        save_dict["model"] = nd_save_dict
    # Forward the caller's strict_load instead of hard-coding True
    super().create_model(save_dict=save_dict, strict_load=strict_load)

By default, the ablator framework loads the model as a whole, i.e., model.load_state_dict(save_dict["model"]).

So in our example, the keys are renamed to nnet1.fc1.weight, nnet1.fc1.bias, nnet1.fc2.weight, and nnet1.fc2.bias, and likewise for nnet2 and nnet3. The flattened keys match those of MyEnsemble's own state dictionary, so the model is loaded correctly when we call super().create_model().
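
The renaming step can be tried out without ablator. In the following sketch, ToyEnsemble and flatten_modules are illustrative names rather than ablator APIs. It nests the state dicts of one ensemble, flattens them again, and loads the result into a second ensemble with strict=True.

```python
import torch
import torch.nn as nn

class NNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(30, 50)
        self.fc2 = nn.Linear(50, 2)

class ToyEnsemble(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.nnet1 = NNet()
        self.nnet2 = NNet()
        self.nnet3 = NNet()

def flatten_modules(nested: dict) -> dict:
    """Turn {"nnet1": {"fc1.weight": t}} into {"nnet1.fc1.weight": t}."""
    return {
        f"{module}.{key}": value
        for module, state in nested.items()
        for key, value in state.items()
    }

source = ToyEnsemble()
nested = {name: module.state_dict() for name, module in source.named_children()}
flat = flatten_modules(nested)

target = ToyEnsemble()  # freshly initialized, so its weights differ from source
target.load_state_dict(flat, strict=True)  # raises if any key is missing or unexpected

print(sorted(flat)[:2])  # ['nnet1.fc1.bias', 'nnet1.fc1.weight']
print(torch.equal(source.nnet2.fc1.weight, target.nnet2.fc1.weight))  # True
```

Because the flattened keys are exactly the keys of the whole model's state_dict(), strict loading succeeds without any further changes.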

Below is the complete script for the model wrapper, where we provide the datasets via the make_dataloader_train and make_dataloader_val functions and add the multi-module saving and loading code discussed above:

class MyEnsembleWrapper(ModelWrapper):
    def make_dataloader_train(self, run_config: ParallelConfig):
        return DataLoader(train_dataset, batch_size=run_config.train_config.batch_size, shuffle=True)

    def make_dataloader_val(self, run_config: ParallelConfig):
        return DataLoader(test_dataset, batch_size=run_config.train_config.batch_size, shuffle=False)

    def save_dict(self):
        saved_dict = super().save_dict()
        model_state_dict = {
            "nnet1": self.model.nnet1.state_dict(),
            "nnet2": self.model.nnet2.state_dict(),
            "nnet3": self.model.nnet3.state_dict(),
        }
        saved_dict["model"] = model_state_dict

        return saved_dict

    def create_model(self, save_dict=None, strict_load=True):
        if save_dict is not None:
            nd_save_dict = {}
            for nnet in save_dict["model"]:
                for key in save_dict["model"][nnet]:
                    new_key = nnet + "." + key
                    nd_save_dict[new_key] = save_dict["model"][nnet][key]
            save_dict["model"] = nd_save_dict
        super().create_model(save_dict=save_dict, strict_load=strict_load)

Custom evaluation (Optional)#

We will use accuracy and F1 score as evaluation metrics:

def my_accuracy(preds, labels):
    # scikit-learn metrics take the true labels first, then the predictions
    return accuracy_score(labels.flatten(), preds.flatten())

def my_f1_score(preds, labels):
    return f1_score(labels.flatten(), preds.flatten(), average='weighted')
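
scikit-learn's metric functions take the true labels as the first argument and the predictions as the second. For accuracy the order does not change the result, but for a weighted F1 score it can, as the made-up example below illustrates.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Made-up column vectors, shaped like the (N, 1) preds/labels returned by the model
labels = np.array([[1], [0], [1], [1]])
preds = np.array([[1], [0], [0], [1]])

acc = accuracy_score(labels.flatten(), preds.flatten())
f1 = f1_score(labels.flatten(), preds.flatten(), average="weighted")
f1_swapped = f1_score(preds.flatten(), labels.flatten(), average="weighted")

print(acc)             # 0.75
print(f1, f1_swapped)  # the two values differ
```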

Launch the ablation experiment#

Everything is ready; now we can launch the ablation experiment.

shutil.rmtree(run_config.experiment_dir, ignore_errors=True)

wrapper = MyEnsembleWrapper(
    model_class=MyEnsemble,
)

ablator = ParallelTrainer(
    wrapper=wrapper,
    run_config=run_config,
)
ablator.launch(working_directory = os.getcwd())

When the experiment finishes, you should see the model checkpoints in each trial's output folder, and you can verify the multi-module model state structure by loading a checkpoint with torch: torch.load(<path_to_checkpoint>)

You can also rerun the ablation experiment from these checkpoints by setting the init_chkpt parameter in the running config to the checkpoint we saved earlier. Remember to store the checkpoint somewhere other than the experiment directory that you are using for this rerun, because that directory is deleted before the rerun starts.
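
Moving the checkpoint out of the experiment directory can be done with the standard library. The sketch below uses temporary directories and a placeholder file so that it runs on its own. All paths and file names in it are hypothetical; in practice you would copy a real checkpoint file from a trial folder.

```python
import shutil
import tempfile
from pathlib import Path

# Temporary stand-ins for the experiment directory and a safe location outside it
experiment_dir = Path(tempfile.mkdtemp()) / "experiments"
safe_dir = Path(tempfile.mkdtemp()) / "saved_checkpoints"

# Placeholder for a checkpoint produced by a trial (hypothetical layout and name)
checkpoint = experiment_dir / "some_trial" / "best_checkpoints" / "model.pt"
checkpoint.parent.mkdir(parents=True)
checkpoint.write_bytes(b"placeholder")

# Copy the checkpoint out before the experiment directory is removed
safe_dir.mkdir(parents=True)
safe_copy = Path(shutil.copy2(checkpoint, safe_dir))

shutil.rmtree(experiment_dir, ignore_errors=True)
print(safe_copy.exists(), checkpoint.exists())  # True False
```

The path of the copy is then the value to pass as init_chkpt.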

run_config = ParallelConfig(
    train_config=train_config,
    model_config=model_config,
    metrics_n_batches = 300,
    experiment_dir = "/tmp/experiments/",
    device="cuda",
    amp=True,
    random_seed = 42,
    total_trials = 5,
    concurrent_trials = 2,
    search_space = search_space,
    optim_metrics = {"val_loss": "min"},
    optim_metric_name = "val_loss",
    gpu_mb_per_experiment = 512,
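    # NOTE: this path is inside experiment_dir, which the rmtree call below deletes.
    # Copy the checkpoint elsewhere first and point init_chkpt at the copy.
    init_chkpt="/tmp/experiments/<trial-id>/best_checkpoints/<checkpoint file .pt>"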
)

shutil.rmtree(run_config.experiment_dir, ignore_errors=True)

wrapper = MyEnsembleWrapper(
    model_class=MyEnsemble,
)

ablator = ParallelTrainer(
    wrapper=wrapper,
    run_config=run_config,
)

metrics = ablator.launch(working_directory = os.getcwd())

The experiment will load the model from the checkpoint and continue training. This example shows how ablator can be customized to fit your needs.

Conclusion#

In this tutorial, we have shown the flexibility of the ablator framework. With some understanding of the model wrapper, you will be able to adapt the framework to your use cases.