ablator.main.model package
Submodules
ablator.main.model.main module
- class ablator.main.model.main.ModelBase(model_class: type[torch.nn.modules.module.Module])
Bases: ABC
Base class that removes training boilerplate code, with extensible support for multiple use cases. The class follows a stateful initialization paradigm. Users must implement the model-creation and model-loading functionality specific to their use case.
Notes
Class properties are listed by name only; see each property's docstring for more information.
Users must implement the abstract methods to customize the model’s behavior.
Mixed precision training lets some operations use the torch.float32 datatype while other operations use the lower-precision floating-point datatype torch.float16. This saves time and reduces memory usage. Ordinarily, "automatic mixed precision training" means training with torch.autocast and torch.cuda.amp.GradScaler together. More information: https://pytorch.org/docs/stable/amp.html
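To make "some operations use lower precision" concrete, the sketch below runs the same matrix multiplication outside and inside an autocast region. It is plain PyTorch, not ablator code, and it uses `device_type="cpu"` with `torch.bfloat16` only so that it runs without a GPU; on CUDA the default lower-precision dtype is `torch.float16`.

```python
import torch

a = torch.randn(4, 8)
b = torch.randn(8, 3)

# Outside an autocast region, matmul runs in the inputs' dtype (float32).
full = a @ b

# Inside a CPU autocast region, eligible ops such as matmul run in bfloat16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    low = a @ b

print(full.dtype, low.dtype)
```

Only ops on autocast's eligible list are downcast; the input tensors themselves stay float32.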
- Attributes:
- model_class: Type[nn.Module]
The class definition of the model's structure, which is a subclass of nn.Module.
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- train_dataloader: DataLoader
The DataLoader used for model training.
- val_dataloader: Optional[DataLoader]
An optional DataLoader used for model evaluation.
- test_dataloader: Optional[DataLoader]
An optional DataLoader used for model testing.
- logger: Union[SummaryLogger, Dummy]
Records information about the program's operation and model training, such as progress and performance metrics.
- device: str
The device used to run the experiment, e.g. "cuda", "cpu", "cuda:0".
- model_dir: Path
The model directory.
- experiment_dir: Path
The experiment directory.
- autocast: torch.autocast
Enables autocasting for chosen regions. Autocasting automatically chooses the precision of GPU operations to improve performance while maintaining accuracy.
- verbose: bool
If True, prints additional information during training. Only applied to the master process.
- amp: bool
If True, applies automatic mixed precision training; otherwise uses the default precision.
- random_seed: Optional[int]
Sets the seed for generating random numbers.
- progress_bar: Union[ProgressBar, Dummy]
An optional instance of ProgressBar that displays real-time information during training, e.g. time remaining. Only applied to the master process.
- current_checkpoint: Optional[Path]
Directory for the current checkpoint file, by default None.
- train_metrics: Metrics
Training metrics, including model information such as the learning rate and loss value.
- eval_metrics: Metrics | None
Evaluation metrics, available when a val_dataloader is provided.
- current_state: dict
The current state of the model, including run_config, metrics and other necessary states.
- learning_rate: float
The current learning rate.
- total_steps: int
The total steps for the training process.
- epochs: int
The total epochs for the training process.
- current_iteration: int
The current iteration of training.
- best_iteration: int
The iteration with the best loss value.
- best_loss: float
The lowest loss value encountered during training.
- abstract checkpoint(is_best=False)
Abstract method to save a checkpoint of the model. Must be implemented by subclasses. For an example implementation, see the checkpoint method in the ModelWrapper class.
- Parameters:
- is_best: bool, optional
Indicates whether the current checkpoint is the best model so far, by default False.
- abstract config_parser(run_config: RunConfig)
Abstract method to parse the provided configuration.
Must be implemented by subclasses. For an example implementation, see the config_parser method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- abstract create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) → None
Abstract method to create and initialize the model. Must be implemented by subclasses. For an example implementation, see the create_model method in the ModelWrapper class.
- Parameters:
- save_dict: dict[str, ty.Any] | None, optional
A dictionary containing saved model data, such as weights and optimizer state, to be loaded into the model, by default None.
- strict_load: bool, optional
If True, the model is loaded strictly, ensuring that the saved state matches the model's structure exactly. If False, the model can be loaded from a partially matching state, by default True.
- property current_epoch: int
Calculates and returns the current epoch during training.
- Returns:
- int
The current epoch number.
- property epoch_len
Returns the length of an epoch, which is the number of batches in the train_dataloader.
- Returns:
- int
The length of an epoch, represented as the number of batches in the train_dataloader.
- Raises:
- AssertionError
If the train_dataloader is not defined or its length is 0.
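The relationship between epoch_len and current_epoch can be sketched as follows. This is a minimal illustration, assuming that current_epoch is the iteration counter integer-divided by epoch_len; the class `EpochCounter` and the list standing in for a DataLoader are hypothetical and not part of ablator.

```python
class EpochCounter:
    """Hypothetical stand-in showing how epoch_len and current_epoch relate."""

    def __init__(self, train_dataloader, current_iteration=0):
        self.train_dataloader = train_dataloader
        self.current_iteration = current_iteration

    @property
    def epoch_len(self) -> int:
        # Mirrors the documented contract: the dataloader must exist and be non-empty.
        assert self.train_dataloader is not None and len(self.train_dataloader) > 0, (
            "train_dataloader must be defined and non-empty"
        )
        return len(self.train_dataloader)

    @property
    def current_epoch(self) -> int:
        # Assumption: one epoch equals epoch_len iterations (batches).
        return self.current_iteration // self.epoch_len


batches = [object()] * 50  # stands in for a DataLoader with 50 batches
counter = EpochCounter(batches, current_iteration=120)
print(counter.epoch_len, counter.current_epoch)  # 50 2
```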
- property eval_itr
Calculates the interval between evaluations.
- Returns:
- int
The interval between evaluations.
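One plausible way to turn an evaluation frequency into an interval is sketched below. It assumes the frequency is configured as a (possibly fractional) number of epochs and that the result is rounded up to whole iterations; the helper name `interval_in_iterations` is hypothetical, and ablator's actual config fields and rounding may differ. The same conversion would apply to log_itr.

```python
import math


def interval_in_iterations(epoch_fraction: float, epoch_len: int) -> int:
    """Convert an interval expressed in (possibly fractional) epochs into iterations.

    Assumption: the interval is rounded up, so it is never zero for a positive fraction.
    """
    if epoch_fraction <= 0 or epoch_len <= 0:
        raise ValueError("epoch_fraction and epoch_len must be positive")
    return math.ceil(epoch_fraction * epoch_len)


# Evaluate twice per epoch on a 50-batch dataloader: every 25 iterations.
eval_itr = interval_in_iterations(0.5, 50)
# Evaluate every 1.5 epochs on a 33-batch dataloader: ceil(49.5) = 50 iterations.
slow_eval_itr = interval_in_iterations(1.5, 33)
```

A training loop would then treat an iteration as an evaluation step when `current_iteration % eval_itr == 0`.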
- abstract evaluate(run_config: RunConfig)
Abstract method to evaluate the model. Must be implemented by subclasses. For an example implementation, see the evaluate method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- abstract evaluation_functions() → dict[str, collections.abc.Callable] | None
Abstract method to create and return a dictionary of evaluation functions used during training and validation.
Must be implemented by subclasses. For an example implementation, see the evaluation_functions method in the ModelWrapper class.
- Returns:
- dict[str, Callable] | None
A dictionary with the evaluation function names as keys and the functions as values.
- abstract load_checkpoint(save_dict: dict[str, Any], model_only: bool = False) → None
Abstract method to load the model and its state from a given save dictionary.
Must be implemented by subclasses. For an example implementation, see the load_checkpoint method in the ModelWrapper class.
- Parameters:
- save_dict: dict[str, ty.Any]
A dictionary containing the saved model state and other necessary information.
- model_only: bool, optional, default=False
If True, only the model's weights are loaded, ignoring other state information.
- property log_itr
Calculates the interval between logging steps.
- Returns:
- int
The interval between logging steps.
- abstract make_dataloaders(run_config: RunConfig)
Abstract method to create dataloaders for the training, validation, and testing datasets.
This method should define how the data is loaded and how the dataloaders for the training, validation, and testing datasets are created based on the provided run_config. Must be implemented by subclasses. For an example implementation, see the make_dataloaders method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- abstract save_dict() → dict[str, Any] | None
Abstract method to create and return a save dictionary containing the model's state and other necessary information.
Must be implemented by subclasses. For an example implementation, see the save_dict method in the ModelWrapper class.
- Returns:
- dict[str, ty.Any] | None
A dictionary containing the saved model state and other necessary information.
- abstract train(run_config: RunConfig, smoke_test: bool = False)
Abstract method to train the model. Must be implemented by subclasses. For an example implementation, see the train method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- smoke_test: bool, optional
Whether to run as a smoke test, by default False.
- property train_stats: OrderedDict
Returns an ordered dictionary containing the current training statistics.
- Returns:
- OrderedDict
An ordered dictionary with the following keys and values:
- learning_rate: The current learning rate.
- total_steps: The total steps for the training process.
- epochs: The number of epochs for training.
- current_epoch: The current epoch during training.
- current_iteration: The current iteration during training.
- best_iteration: The iteration with the best loss value so far.
- best_loss: The best (lowest) loss value achieved during training.
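For reference, a dictionary with the documented keys in the documented order can be built as below. The helper `make_train_stats` and the numbers passed to it are made up for illustration; in ablator the values come from the model's training state.

```python
from collections import OrderedDict


def make_train_stats(learning_rate, total_steps, epochs, current_epoch,
                     current_iteration, best_iteration, best_loss) -> OrderedDict:
    """Hypothetical helper assembling the documented keys in the documented order."""
    return OrderedDict(
        [
            ("learning_rate", learning_rate),
            ("total_steps", total_steps),
            ("epochs", epochs),
            ("current_epoch", current_epoch),
            ("current_iteration", current_iteration),
            ("best_iteration", best_iteration),
            ("best_loss", best_loss),
        ]
    )


stats = make_train_stats(1e-3, 1000, 10, 2, 250, 200, 0.42)
print(list(stats))  # keys in insertion order
```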
- property uid
Returns a unique identifier (UID) for the current run configuration.
- Returns:
- str
A string representing the unique identifier of the current run configuration.
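A common way to derive a stable identifier from a configuration is to hash a canonical serialization of it. The sketch below does this with SHA-256 over key-sorted JSON; it is an assumption about the general technique, not necessarily how ablator computes uid, and `config_uid` is a hypothetical name.

```python
import hashlib
import json


def config_uid(config: dict, length: int = 10) -> str:
    """Derive a deterministic identifier from a JSON-serializable config.

    Keys are sorted so that configs with the same content, in any key order,
    map to the same identifier.
    """
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:length]


a = config_uid({"lr": 0.001, "optimizer": "adam", "epochs": 10})
b = config_uid({"epochs": 10, "optimizer": "adam", "lr": 0.001})
c = config_uid({"epochs": 20, "optimizer": "adam", "lr": 0.001})
```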
ablator.main.model.wrapper module
- class ablator.main.model.wrapper.ModelWrapper(model_class: type[torch.nn.modules.module.Module])
Bases: ModelBase
A wrapper around model_class that removes training boilerplate code. Its methods can be overridden to support custom use cases. Once make_dataloader_train is overridden to provide a training dataset, the ModelWrapper object is passed to a trainer (ProtoTrainer or ParallelTrainer) along with the run configuration to launch the experiment.
- Attributes:
- model_class: torch.nn.Module
The model class to wrap.
- model: torch.nn.Module
The model, created from the model class or from a checkpoint.
- optimizer: Optimizer
The optimizer, created from the optimizer config or from a checkpoint.
- scaler: GradScaler
The scaler, created from the scaler config or from a checkpoint.
- scheduler: Scheduler
The scheduler, created from the scheduler config or from a checkpoint.
- apply_loss(model: Module, loss: Tensor | None, optimizer: Optimizer, scaler: GradScaler, scheduler: _LRScheduler | ReduceLROnPlateau | Any | None) → float | None
Calculates the loss, applies the gradients, and calls optimizer.step() and scheduler.step().
- Parameters:
- model: nn.Module
The model to apply the loss to.
- loss: torch.Tensor | None
The loss to apply.
- optimizer: Optimizer
The optimizer to step.
- scaler: torch.cuda.amp.GradScaler
The scaler to use for mixed precision training.
- scheduler: ty.Optional[Scheduler]
The scheduler to step.
- Returns:
- float | None
The loss value.
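The documented sequence (backward through the GradScaler, optimizer step, scheduler step, return the loss as a float) follows the standard PyTorch AMP recipe. The sketch below is a hypothetical re-implementation, `apply_loss_sketch`, not ablator's code; the GradScaler is disabled when no GPU is present, in which case it simply passes calls through, so the example also runs on CPU.

```python
import torch
from torch import nn


def apply_loss_sketch(model, loss, optimizer, scaler, scheduler=None):
    """Backward through the scaler, step the optimizer, then the scheduler.

    `model` is accepted only to mirror the documented signature; it is unused here.
    Returns the loss as a Python float, or None when there is no loss.
    """
    if loss is None:
        return None
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()  # scales the loss when the scaler is enabled
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()
    if scheduler is not None:
        scheduler.step()
    return loss.item()


use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)
loss = nn.functional.mse_loss(model(x), y)
loss_value = apply_loss_sketch(model, loss, optimizer, scaler, scheduler)
print(loss_value, optimizer.param_groups[0]["lr"])
```

Note that a ReduceLROnPlateau scheduler expects a metric in step(), so it cannot be stepped unconditionally like this.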
- checkpoint(is_best=False)
Saves a checkpoint of the model, using the model's class name as the filename.
- Parameters:
- is_best: bool
Whether this is the best model so far.
- config_parser(run_config: RunConfig)
You can override this function to initialize Derived properties whose values are not known until the experiment is launched.
Examples
For example, in a GPT2 model the vocabulary size must be resized to match the tokenizer's vocabulary size, which depends on the tokenizer used. This is only known after the experiment has been launched, because you might want to run an ablation study over different tokenizers:
>>> class MyLMWrapper(ModelWrapper):
...     def config_parser(self, run_config: RunConfig):
...         run_config.model_config.resize_token_embedding = len(self.train_dataloader.dataset.tokenizer)
...         return run_config
- create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) → None
Creates the model, optimizer, scheduler, and scaler from a saved checkpoint dictionary or from the config. You can override this function together with save_dict() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.
- Parameters:
- save_dict: dict[str, ty.Any]
The saved checkpoint dictionary to load from.
- strict_load: bool
Whether to load the model strictly or not.
- create_optimizer(model: Module, optimizer_config: OptimizerConfig | None = None, optimizer_state: dict[str, Any] | None = None) → Optimizer
Creates the optimizer from the saved state or from config.
- Parameters:
- model: nn.Module
The model to create the optimizer for.
- optimizer_config: OptimizerConfig
The optimizer config to create the optimizer from.
- optimizer_state: dict[str, ty.Any]
The optimizer state to load the optimizer from.
- Returns:
- optimizer: Optimizer
The optimizer.
- create_scaler(scaler_state: dict | None = None) → GradScaler
Creates the scaler from the saved state or from config.
- Parameters:
- scaler_state: dict[str, ty.Any]
The scaler state to load the scaler from.
- Returns:
- scaler: GradScaler
The scaler.
- create_scheduler(model: Module, optimizer: Optimizer, scheduler_config: SchedulerConfig | None = None, scheduler_state: dict | None = None) → _LRScheduler | ReduceLROnPlateau | Any | None
Creates the scheduler from the saved state or from config.
- Parameters:
- model: nn.Module
The model to create the scheduler for.
- optimizer: Optimizer
The optimizer to create the scheduler for.
- scheduler_config: SchedulerConfig
The scheduler config to create the scheduler from.
- scheduler_state: dict[str, ty.Any]
The scheduler state to load the scheduler from.
- Returns:
- scheduler: Scheduler
The scheduler.
- property epochs
The total number of epochs.
- final eval(smoke_test=False)
If the current iteration is an evaluation step, evaluates the model, then updates the scheduler and saves a checkpoint. It also checks whether training should stop early (see the Model Configuration module for more details).
- final evaluate(run_config: RunConfig)
Evaluates the model on the test and validation sets after training.
- Parameters:
- run_config: RunConfig
The run config to use for evaluation.
- evaluation_functions() → dict[str, collections.abc.Callable] | None
You can override this function to return the evaluation-function callables used to evaluate experiment metrics.
- Returns:
- dict[str, Callable]
The evaluation functions to use. See also Metrics for details.
Examples
Define the callables:
>>> from sklearn.metrics import accuracy_score, f1_score
>>> def my_accuracy(y_true, y_pred):
...     return accuracy_score(y_true.flatten(), y_pred.flatten())
>>> def my_f1_score(y_true, y_pred):
...     return f1_score(y_true.flatten(), y_pred.flatten(), average='weighted')
Return the callables from evaluation_functions:
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
...     def evaluation_functions(self):
...         return {
...             "accuracy": my_accuracy,
...             "f1": my_f1_score
...         }
Note that the callables' parameter names must match the keys of the output dictionary returned by the model's forward method. In this example, forward must return y_true and y_pred:
>>> class MyModel(nn.Module):
...     def __init__(self, config: CustomModelConfig) -> None:
...         super().__init__()
...         self.model = FashionMNISTModel(config)
...         self.loss = nn.CrossEntropyLoss()
...     def forward(self, x, labels=None):
...         out = self.model(x)
...         loss = None
...         if labels is not None:
...             loss = self.loss(out, labels)
...         out = out.argmax(dim=-1)
...         return {"y_pred": out, "y_true": labels}, loss
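The naming contract can be illustrated by calling an evaluation function with the output entries whose keys match its parameter names. This is only a sketch of the contract, assuming name-based dispatch via `inspect.signature`; the helper `call_by_output_names` is hypothetical, and ablator's Metrics class may accumulate outputs across batches before calling the functions.

```python
import inspect


def my_accuracy(y_true, y_pred):
    # Fraction of positions where the prediction equals the target.
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def call_by_output_names(fn, outputs: dict):
    """Call fn with the entries of `outputs` whose keys match fn's parameter names."""
    names = inspect.signature(fn).parameters
    missing = [n for n in names if n not in outputs]
    if missing:
        raise KeyError(f"model output is missing keys required by {fn.__name__}: {missing}")
    return fn(**{n: outputs[n] for n in names})


# Same keys as the dict returned (first element) by MyModel.forward above.
outputs = {"y_pred": [1, 0, 2, 2], "y_true": [1, 0, 1, 2]}
score = call_by_output_names(my_accuracy, outputs)
```

A function whose parameters were named, say, `labels` and `preds` would raise a KeyError here, which is why the names must line up with the forward output.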
- load_checkpoint(save_dict: dict[str, Any], model_only: bool = False) → None
Loads the checkpoint from the save dict.
- Parameters:
- save_dict: dict[str, ty.Any]
The save dict to load the checkpoint from.
- model_only: bool
Whether to load only the model or include scheduler, optimizer and scaler.
Notes
This method is the implementation of the abstract method in the base class.
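The difference between model_only=True and model_only=False can be sketched with plain PyTorch state dicts. The key names "model", "optimizer" and "scheduler" and the helpers `make_save_dict`, `load_checkpoint_sketch` and `build` are hypothetical; the real save_dict() layout may differ and also includes the scaler.

```python
import copy

import torch
from torch import nn


def make_save_dict(model, optimizer, scheduler):
    # Hypothetical key names; the real save_dict() layout may differ.
    return {
        "model": copy.deepcopy(model.state_dict()),
        "optimizer": copy.deepcopy(optimizer.state_dict()),
        "scheduler": scheduler.state_dict(),
    }


def load_checkpoint_sketch(save_dict, model, optimizer, scheduler, model_only=False):
    model.load_state_dict(save_dict["model"])
    if model_only:
        return  # keep the freshly created optimizer and scheduler state
    optimizer.load_state_dict(save_dict["optimizer"])
    scheduler.load_state_dict(save_dict["scheduler"])


def build():
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, scheduler


# Train for one step so that the saved state differs from a fresh one.
model, optimizer, scheduler = build()
model(torch.randn(2, 3)).sum().backward()
optimizer.step()
scheduler.step()  # lr: 0.1 -> 0.05
ckpt = make_save_dict(model, optimizer, scheduler)

# model_only=True: weights restored, the optimizer keeps its initial lr.
m1, o1, s1 = build()
load_checkpoint_sketch(ckpt, m1, o1, s1, model_only=True)

# model_only=False: weights and optimizer/scheduler state restored.
m2, o2, s2 = build()
load_checkpoint_sketch(ckpt, m2, o2, s2, model_only=False)
```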
- log()
Logs if the current iteration is a logging step. It also evaluates the training metrics for logging.
- log_step()
A single logging step.
Notes
This method updates the logger with the current metrics and logs a status message.
- make_dataloader_test(run_config: RunConfig) → DataLoader | None
Creates the test dataloader. You can override this function to return a test dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The test dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_test(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             test_dataset,
...             batch_size=32,
...             shuffle=False
...         )
- abstract make_dataloader_train(run_config: RunConfig) → DataLoader
Creates the training dataloader. You must override this function to return the training dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader
The training dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
- make_dataloader_val(run_config: RunConfig) → DataLoader | None
Creates the validation dataloader. You can override this function to return a validation dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The validation dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
- make_dataloaders(run_config: RunConfig) → None
Creates the dataloaders. This is done post-initialization because otherwise the dataloaders would be pickled together with the object when running distributed.
- property metrics: dict[str, float]
The metrics of the current training state. If eval_metrics are defined (e.g. a validation dataloader was provided), they are combined with the training metrics, and the prefixes val_ and train_ are prepended to the respective metric names. The current training statistics are also included.
- Returns:
- dict[str, float]
A dictionary mapping each metric name to its value.
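A minimal sketch of the combination described above, reading the description literally: prefixes are added only when evaluation metrics exist, and the training statistics are merged in unprefixed. The function `combine_metrics` is hypothetical, and ablator's exact prefixing rules may differ.

```python
def combine_metrics(train_metrics, eval_metrics=None, train_stats=None):
    """Hypothetical sketch: prefix train/val metric names and merge in train stats."""
    if eval_metrics is None:
        combined = dict(train_metrics)
    else:
        combined = {f"train_{k}": v for k, v in train_metrics.items()}
        combined.update({f"val_{k}": v for k, v in eval_metrics.items()})
    if train_stats is not None:
        combined.update(train_stats)
    return combined


metrics = combine_metrics(
    {"loss": 0.31, "accuracy": 0.90},
    {"loss": 0.35, "accuracy": 0.88},
    {"learning_rate": 1e-3, "current_epoch": 2},
)
```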
- final mock_train(run_config: RunConfig | None = None, run_async=True, block: bool = True) → Process | dict[str, float]
Mock-trains the model as a smoke test.
- Parameters:
- run_config: RunConfig
The run config to use for the mock train.
- run_async: bool
Whether to run the mock train in a separate process.
- block: bool
Whether to block the current process until the mock train is finished.
- Returns:
- p: mp.Process
The process running the mock train.
- metrics: dict[str, float]
The metrics from the mock train.
- property model_config: ModelConfig
- model_step(model: Module, batch: Iterable) → tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]
A single inference step for the model.
- Parameters:
- model: nn.Module
The model to train.
- batch: Iterable
The batch of input data to pass through the model; it can be a list, a dict, or a single tensor.
- Returns:
- out: tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]
The output of the model, containing the current predictions and the loss.
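Since a batch may be a list, a dict, or a single tensor, a model step has to decide how to pass it to forward. The sketch below assumes dicts are unpacked as keyword arguments, lists and tuples as positional arguments, and anything else is passed as a single argument; `forward_batch` and `ToyModel` are hypothetical, and ablator's own dispatch may differ.

```python
import torch
from torch import nn


def forward_batch(model, batch):
    """Hypothetical dispatch of a batch to model.forward based on its container type."""
    if isinstance(batch, dict):
        return model(**batch)  # keyword arguments, e.g. {"x": ..., "labels": ...}
    if isinstance(batch, (list, tuple)):
        return model(*batch)   # positional arguments, e.g. [x, labels]
    return model(batch)        # a single tensor


class ToyModel(nn.Module):
    # Returns (outputs_dict, loss), like MyModel in the evaluation_functions example.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 3)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, labels=None):
        logits = self.linear(x)
        loss = None if labels is None else self.loss(logits, labels)
        return {"y_pred": logits.argmax(dim=-1), "y_true": labels}, loss


torch.manual_seed(0)
model = ToyModel()
x = torch.randn(5, 4)
labels = torch.randint(0, 3, (5,))
out_d, loss_d = forward_batch(model, {"x": x, "labels": labels})
out_l, loss_l = forward_batch(model, [x, labels])
out_t, loss_t = forward_batch(model, x)  # no labels, so no loss
```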
- save_dict() → dict[str, Any]
Collects the current state of the trainer, including the model parameters and the current states of the optimizer, scaler, and scheduler. You can override this function together with create_model() to customize how the model, optimizer, and scheduler are saved and loaded. An example is shown in the Saving and loading multi-module models tutorial.
- Returns:
- dict[str, ty.Any]
The current state of the trainer.
- status_message() → str
Returns a status string generated from the dictionary of current metrics, including all static metrics and moving-average metrics.
- Returns:
- str
The status message.
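A status string of this kind is typically a compact join of metric names and values. The formatting below (scientific notation with two decimals for floats, plain `str()` otherwise, comma-separated) is an assumption for illustration, not ablator's actual format, and `format_status` is a hypothetical name.

```python
def format_status(metrics: dict) -> str:
    """Hypothetical formatting of a metrics dict into a one-line status message."""
    parts = []
    for name, value in metrics.items():
        if isinstance(value, float):
            parts.append(f"{name}: {value:.2e}")
        else:
            parts.append(f"{name}: {value}")
    return ", ".join(parts)


msg = format_status({"train_loss": 0.25, "current_epoch": 2})
print(msg)  # train_loss: 2.50e-01, current_epoch: 2
```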
- to_device(data: Iterable, device=None) → Iterable
Moves the data to the specified device.
- Parameters:
- data: Iterable
The data to move to the device.
- device: ty.Optional[ty.Union[torch.device, str]]
The device to move the data to. If
None, the device specified in the config is used.
- Returns:
- data: Iterable
The data on the device.
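Moving an arbitrary batch to a device usually means walking nested containers and calling `.to(device)` on each tensor. The sketch below assumes dicts, lists and tuples are the only containers and leaves other values unchanged; `to_device_sketch` is hypothetical and, unlike the documented method, requires an explicit device.

```python
import torch


def to_device_sketch(data, device):
    """Recursively move tensors in nested lists/tuples/dicts to `device`.

    Non-tensor leaves (ints, strings, None, ...) are returned unchanged.
    Note: rebuilding via type(data)(...) does not support namedtuples.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: to_device_sketch(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_device_sketch(v, device) for v in data)
    return data


device = "cuda" if torch.cuda.is_available() else "cpu"
batch = {"x": torch.zeros(2, 3), "meta": [torch.ones(1), "id-7"], "labels": (torch.arange(2),)}
moved = to_device_sketch(batch, device)
```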
- property total_steps
The total number of steps for training.
- final train(run_config: RunConfig, smoke_test: bool = False, debug: bool = False, resume: bool = False, remote_progress_bar: RemoteProgressBar | None = None) → dict[str, float]
Initializes the training state and trains the model. On a keyboard interrupt, saves a checkpoint.
- Parameters:
- run_config: RunConfig
The run config to use for training.
- smoke_test: bool, default=False
Whether to run a smoke test.
- debug: bool, default=False
Whether to run in debug mode.
- resume: bool, default=False
Whether to resume training the model from existing checkpoints and the existing experiment state.
- remote_progress_bar: RemoteProgressBar, optional
An optional remote progress bar to which training progress is reported.
- Returns:
- dict[str, float]
The metrics from the training.
- property train_config: TrainConfig
- train_loop(smoke_test=False)
Trains the model over many steps, evaluating the model and logging the metrics at each iteration. The metrics include static metrics such as the learning rate, along with validation and training metrics such as loss and mean.
- Parameters:
- smoke_test: bool
Whether to run a smoke test.
- final train_step(batch: Iterable) → tuple[dict[str, torch.Tensor] | None, dict[str, Any]]
A single training step. It also updates the learning rate with the scheduler.
- Parameters:
- batch: Iterable
The batch of input data to pass through the model; it can be a list, a dict, or a single tensor.
- Returns:
- outputs: dict[str, torch.Tensor] | None
The output of the model.
- train_metrics: dict[str, ty.Any]
The training metrics.
- update_status()
Updates the metrics with the current training stats, then sets all metrics (static and moving average) as the description of the tqdm progress bar.
- validation_loop(model: Module, dataloader: DataLoader, metrics: Metrics, subsample: float = 1.0, smoke_test: bool = False) → dict[str, float]
Validates the model on the data from dataloader and stores the results in metrics.
- Parameters:
- model: nn.Module
The model to validate.
- dataloader: DataLoader
The dataloader to use for validation.
- metrics: Metrics
The metrics to use for validation.
- subsample: float
The fraction of the dataloader to use for validation.
- smoke_test: bool
Whether to execute this function as a smoke test. If True, only one iteration is performed, which is useful for quickly checking that the code runs without errors. Default is False.
- Returns:
- dict[str, float]
The metrics from the validation.
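How subsample and smoke_test could restrict the batches visited is sketched below, assuming subsample keeps the first `ceil(subsample * len(dataloader))` batches and a smoke test keeps only the first one. Whether ablator takes the leading batches or samples them differently is not stated above, so treat `batches_to_run` as a hypothetical illustration.

```python
import itertools
import math


def batches_to_run(dataloader, subsample: float = 1.0, smoke_test: bool = False):
    """Hypothetical selection of validation batches.

    Assumption: `subsample` keeps the first ceil(subsample * len(dataloader)) batches;
    a smoke test keeps only the first batch.
    """
    if not 0 < subsample <= 1:
        raise ValueError("subsample must be in (0, 1]")
    n = 1 if smoke_test else math.ceil(subsample * len(dataloader))
    return itertools.islice(dataloader, n)


loader = list(range(10))  # stands in for a DataLoader with 10 batches
quarter = list(batches_to_run(loader, subsample=0.25))  # ceil(2.5) = 3 batches
smoke = list(batches_to_run(loader, smoke_test=True))   # 1 batch
```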