{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Saving and loading multi-module models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ablator is a flexible framework: you can override its functions to fit your use case. In this tutorial, we show how to customize ablator so that it can save and load multi-module models. This is helpful when a model consists of multiple modules and you want to save the entire model to a file and load it back later. Sample use cases include the encoder and decoder blocks of a transformer model and ensemble models.\n", "\n", "For demonstration purposes, we will create an ensemble of three simple one-hidden-layer neural networks, train it on the breast cancer dataset for 30 epochs, save the ensemble as a 3-module model, then load it back and train for another 30 epochs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us first import the necessary modules:\n", "\n", "```python\n", "from ablator import ModelConfig, OptimizerConfig, TrainConfig, ParallelConfig\n", "from ablator import ModelWrapper, ParallelTrainer\n", "from ablator.main.configs import SearchSpace\n", "\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from sklearn.datasets import load_breast_cancer\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import MinMaxScaler\n", "from sklearn.metrics import f1_score, accuracy_score # for custom evaluation functions\n", "\n", "import shutil\n", "import os\n", "import typing as ty  # for the type hints in the model wrapper\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparing the data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "class BreastCancerDataset(Dataset):\n", "    def __init__(self, data, targets):\n", "        self.data = data\n", "        self.scaler = MinMaxScaler()\n", "        self.data = 
self.scaler.fit_transform(self.data)\n", "        self.targets = targets\n", "\n", "    def __getitem__(self, index):\n", "        x = self.data[index]\n", "        y = self.targets[index]\n", "        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)\n", "\n", "    def __len__(self):\n", "        return len(self.data)\n", "\n", "# Load dataset from scikit-learn\n", "breast_cancer = load_breast_cancer()\n", "data = breast_cancer.data\n", "targets = breast_cancer.target\n", "\n", "# Split the data into train and test sets\n", "train_data, test_data, train_targets, test_targets = train_test_split(data, targets, test_size=0.2, random_state=42)\n", "\n", "# Create train and test datasets\n", "train_dataset = BreastCancerDataset(train_data, train_targets)\n", "test_dataset = BreastCancerDataset(test_data, test_targets)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build the ensemble model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Simple 1-hidden layer neural network module" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create a simple NN module with one hidden layer of size 50, followed by a ReLU activation. The input size is 30, matching the 30 features of the dataset, and the output layer has size two, corresponding to the two classes of the dataset.\n", "\n", "```python\n", "class NNet(nn.Module):\n", "    def __init__(self) -> None:\n", "        super().__init__()\n", "        self.fc1 = nn.Linear(30, 50)\n", "        self.fc2 = nn.Linear(50, 2)\n", "        self.relu = nn.ReLU()\n", "\n", "    def forward(self, x):\n", "        x = self.fc1(x)\n", "        x = self.relu(x)\n", "        x = self.fc2(x)\n", "\n", "        return x\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Assemble the modules\n", "\n", "We create the ensemble model named `MyEnsemble`. It consists of 3 separate neural networks. The final prediction probability is calculated by summing the outputs of the 3 networks and applying softmax to the result." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "class MyEnsemble(nn.Module):\n", "    def __init__(self, *args, **kwargs) -> None:\n", "        super().__init__()\n", "        self.nnet1 = NNet()\n", "        self.nnet2 = NNet()\n", "        self.nnet3 = NNet()\n", "\n", "    def forward(self, x, labels=None):\n", "        x1 = self.nnet1(x)\n", "        x2 = self.nnet2(x)\n", "        x3 = self.nnet3(x)\n", "\n", "        # Sum the raw outputs (logits) of the three networks\n", "        logits = x1 + x2 + x3\n", "        # Softmax turns the summed logits into class probabilities\n", "        ensemble = F.softmax(logits, dim=1)\n", "\n", "        # F.cross_entropy applies log-softmax internally, so it takes the logits\n", "        loss = F.cross_entropy(logits, labels)\n", "        preds = torch.argmax(ensemble, dim=1)\n", "\n", "        preds = preds.reshape(-1, 1)\n", "        labels = labels.reshape(-1, 1)\n", "\n", "        return {\"preds\": preds, \"labels\": labels}, loss\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configure the experiment\n", "Now we set up the ablation experiment by defining its configurations." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model configuration\n", "Since we're not ablating the model architecture, no custom model configuration is needed, so we use an empty one.\n", "```python\n", "model_config = ModelConfig()\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Optimizer configuration\n", "We will use the Adam optimizer, with the learning rate initialized to 0.001." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "optimizer_config = OptimizerConfig(\n", "    name=\"adam\",\n", "    arguments={\"lr\": 0.001}\n", ")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will also define a search space for different learning rate values:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "search_space = {\n", "    \"train_config.optimizer_config.arguments.lr\": SearchSpace(\n", "        value_range = [0.001, 0.01],\n", "        value_type = 'float'\n", "    )\n", "}\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training configuration" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "train_config = TrainConfig(\n", "    dataset=\"breast-cancer\",\n", "    batch_size=32,\n", "    epochs=30,\n", "    optimizer_config=optimizer_config,\n", "    scheduler_config=None,\n", "    rand_weights_init = True\n", ")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Running configuration (parallel config)\n", "Combine model configuration, train configuration, and search space into the running configuration:\n", "\n", "```python\n", "run_config = ParallelConfig(\n", "    train_config=train_config,\n", "    model_config=model_config,\n", "    metrics_n_batches = 800,\n", "    experiment_dir = \"/tmp/experiments/\",\n", "    device=\"cuda\",\n", "    amp=True,\n", "    random_seed = 42,\n", "    total_trials = 5,\n", "    concurrent_trials = 3,\n", "    search_space = search_space,\n", "    optim_metrics = {\"val_loss\": \"min\"},\n", "    gpu_mb_per_experiment = 1024\n", ")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model wrapper\n", "In addition to overriding the data loader and evaluation functions, we modify the model saving and loading functions of `ModelWrapper` so that ablator saves the model as a 3-module model:\n", "\n", "```python\n", "{\n", "    \"model\": {\n", "        \"nnet1\": self.model.nnet1.state_dict(),\n", "        \"nnet2\": self.model.nnet2.state_dict(),\n", "        \"nnet3\": self.model.nnet3.state_dict()\n", "    }\n", "}\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multi-module model saving\n", "We override the `ModelWrapper.save_dict()` function to save the entire model as a dictionary of modules.\n", "```python\n", "    def save_dict(self):\n", "        saved_dict = super().save_dict()\n", "        model_state_dict = {\n", "            \"nnet1\": self.model.nnet1.state_dict(),\n", "            \"nnet2\": self.model.nnet2.state_dict(),\n", "            \"nnet3\": self.model.nnet3.state_dict(),\n", "        }\n", "        saved_dict[\"model\"] = model_state_dict\n", "\n", "        return saved_dict\n", "```\n", 
"By default, the ablator framework saves the model as a whole, i.e., `saved_dict[\"model\"] = self.model.state_dict()`.\n", "\n", "In our example, the modules `nnet1`, `nnet2`, and `nnet3` of `MyEnsemble` can be accessed via `self.model.nnet1`, `self.model.nnet2`, and `self.model.nnet3` respectively, and we save these modules' state dictionaries into `saved_dict[\"model\"]`.\n", "\n", "This way, the saved model is a dictionary of modules, each holding the parameters of the two linear layers of its `NNet`:\n", "```python\n", "saved_dict = {\n", "    \"model\": {\n", "        \"nnet1\": {\"fc1.weight\": weight, \"fc1.bias\": bias, \"fc2.weight\": weight, \"fc2.bias\": bias},\n", "        \"nnet2\": {\"fc1.weight\": weight, \"fc1.bias\": bias, \"fc2.weight\": weight, \"fc2.bias\": bias},\n", "        \"nnet3\": {\"fc1.weight\": weight, \"fc1.bias\": bias, \"fc2.weight\": weight, \"fc2.bias\": bias},\n", "    },\n", "    ...\n", "}\n", "```\n", "After running the experiment, you can use `torch.load(<path_to_checkpoint>)` to verify this, where `<path_to_checkpoint>` is the path to one of the models that are saved in the experiment directory." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multi-module model loading\n", "Now that we have saved a multi-module model, we also need to change how ablator loads the model. We do this by overriding the `ModelWrapper.create_model()` function.\n", "```python\n", "    def create_model(\n", "        self,\n", "        save_dict: dict[str, ty.Any] | None = None,\n", "        strict_load: bool = True,\n", "    ) -> None:\n", "        if save_dict is not None:\n", "            # Flatten {module: {key: tensor}} into {\"module.key\": tensor}\n", "            nd_save_dict = {}\n", "            for nnet in save_dict[\"model\"]:\n", "                for key in save_dict[\"model\"][nnet]:\n", "                    new_key = nnet + \".\" + key\n", "                    nd_save_dict[new_key] = save_dict[\"model\"][nnet][key]\n", "            save_dict[\"model\"] = nd_save_dict\n", "        super().create_model(save_dict=save_dict, strict_load=strict_load)\n", "```\n", "By default, the ablator framework loads the model as a whole, i.e., `model.load_state_dict(save_dict[\"model\"])`.\n", "\n", "So in our example, the keys are updated to `nnet1.fc1.weight`, `nnet1.fc1.bias`, `nnet1.fc2.weight`, `nnet1.fc2.bias`, and likewise for `nnet2` and `nnet3`, which is the flat key layout that the state dictionary of `MyEnsemble` uses. 
This way, when we call `super().create_model()`, the model is loaded correctly." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is the complete script for the model wrapper, where we provide the datasets via the `make_dataloader_train` and `make_dataloader_val` functions and add the multi-module saving and loading code discussed above:\n", "\n", "```python\n", "class MyEnsembleWrapper(ModelWrapper):\n", "    def make_dataloader_train(self, run_config: ParallelConfig):\n", "        return DataLoader(train_dataset, batch_size=run_config.train_config.batch_size, shuffle=True)\n", "\n", "    def make_dataloader_val(self, run_config: ParallelConfig):\n", "        return DataLoader(test_dataset, batch_size=run_config.train_config.batch_size, shuffle=False)\n", "\n", "    def save_dict(self):\n", "        saved_dict = super().save_dict()\n", "        model_state_dict = {\n", "            \"nnet1\": self.model.nnet1.state_dict(),\n", "            \"nnet2\": self.model.nnet2.state_dict(),\n", "            \"nnet3\": self.model.nnet3.state_dict(),\n", "        }\n", "        saved_dict[\"model\"] = model_state_dict\n", "\n", "        return saved_dict\n", "\n", "    def create_model(self, save_dict=None, strict_load=True):\n", "        if save_dict is not None:\n", "            nd_save_dict = {}\n", "            for nnet in save_dict[\"model\"]:\n", "                for key in save_dict[\"model\"][nnet]:\n", "                    new_key = nnet + \".\" + key\n", "                    nd_save_dict[new_key] = save_dict[\"model\"][nnet][key]\n", "            save_dict[\"model\"] = nd_save_dict\n", "        super().create_model(save_dict=save_dict, strict_load=strict_load)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Custom evaluation (Optional)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use accuracy and F1 score as evaluation metrics:\n", "\n", "```python\n", "# scikit-learn metrics take the true labels first, then the predictions\n", "def my_accuracy(preds, labels):\n", "    return accuracy_score(labels.flatten(), preds.flatten())\n", "\n", "def my_f1_score(preds, labels):\n", "    return f1_score(labels.flatten(), preds.flatten(), average='weighted')\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Launch 
the ablation experiment\n", "Everything is ready, so we can now launch the ablation experiment." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "shutil.rmtree(run_config.experiment_dir, ignore_errors=True)\n", "\n", "wrapper = MyEnsembleWrapper(\n", "    model_class=MyEnsemble,\n", ")\n", "\n", "ablator = ParallelTrainer(\n", "    wrapper=wrapper,\n", "    run_config=run_config,\n", ")\n", "ablator.launch(working_directory = os.getcwd(), ray_head_address=\"auto\")\n", "```\n", "\n", "When the experiment finishes, you should see the model checkpoints in each trial's output folder, and you can verify the multi-module model state structure by loading a checkpoint with torch: `torch.load(<path_to_checkpoint>)`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also rerun the ablation experiment from these checkpoints by specifying the `init_chkpt` parameter in the running config, which loads the model from the checkpoint we saved earlier. Remember to store the checkpoint somewhere other than the experiment directory used for this rerun, because the script below deletes that directory before launching." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "run_config = ParallelConfig(\n", "    train_config=train_config,\n", "    model_config=model_config,\n", "    metrics_n_batches = 800,\n", "    experiment_dir = \"/tmp/experiments/\",\n", "    device=\"cuda\",\n", "    amp=True,\n", "    random_seed = 42,\n", "    total_trials = 5,\n", "    concurrent_trials = 3,\n", "    search_space = search_space,\n", "    optim_metrics = {\"val_loss\": \"min\"},\n", "    gpu_mb_per_experiment = 1024,\n", "    init_chkpt=\"/tmp/experiments1/experiment_7ae3_9991/2ca5_9991/best_checkpoints/MyEnsemble_0000000210.pt\"\n", ")\n", "\n", "shutil.rmtree(run_config.experiment_dir, ignore_errors=True)\n", "\n", "wrapper = MyEnsembleWrapper(\n", "    model_class=MyEnsemble,\n", ")\n", "\n", "ablator = ParallelTrainer(\n", "    wrapper=wrapper,\n", "    run_config=run_config,\n", ")\n", "\n", "metrics = ablator.launch(working_directory = os.getcwd(), ray_head_address=None)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The experiment will load the model from the checkpoint and continue training. That's it: this example shows how ablator can be customized to fit your needs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we demonstrated the flexibility of the ablator framework. With some understanding of the model wrapper, you can adapt the framework to your own use cases." ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }