ablator.main.model package#
Submodules#
ablator.main.model.main module#
- class ablator.main.model.main.ModelBase(model_class: type[torch.nn.modules.module.Module])[source]#
Bases:
ABC
Base class that removes training boilerplate code, with extensible support for multiple use cases. The class follows a stateful initialization paradigm. Users must implement the model-loading and model-creation functionality specific to their use case.
- Parameters:
- model_class: type[nn.Module]
The base class for the user’s model, which defines the neural network.
Notes
Class properties are listed by name only; see each property’s docstring for more information.
Users must implement the abstract methods to customize the model’s behavior.
Mixed precision training lets some operations use the torch.float32 datatype while other operations use a lower-precision floating point datatype, torch.float16. This saves time and reduces memory usage. Ordinarily, “automatic mixed precision training” means training with torch.autocast and torch.cuda.amp.GradScaler together. More information: https://pytorch.org/docs/stable/amp.html
- Attributes:
- model_class: Type[nn.Module]
The class definition of the model’s structure, which is a subclass of nn.Module.
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- train_dataloader: DataLoader
A DataLoader object responsible for model training.
- val_dataloader: Optional[DataLoader]
An optional DataLoader object used for model evaluation.
- test_dataloader: Optional[DataLoader]
An optional DataLoader object used for model testing.
- logger: Union[SummaryLogger, Dummy]
Records information on the program’s operation and model training, such as progress and performance metrics.
- device: str
The type of device used for running the experiment, e.g. "cuda", "cpu", "cuda:0".
- model_dir: Path
The model directory.
- experiment_dir: Path
The experiment directory.
- verbose: bool
If True, prints additional information while training. Only applies to the master process.
- amp: bool
If True, applies automatic mixed precision training; otherwise uses default precision.
- random_seed: Optional[int]
Sets the seed for generating random numbers.
- progress_bar: Union[ProgressBar, Dummy]
An optional instance of ProgressBar that displays real-time information during training, e.g. time remaining. Only applies to the master process.
- current_checkpoint: Optional[Path]
Directory for the current checkpoint file, by default None.
- train_metrics: Metrics
Training metrics including model information, e.g. learning rate and loss value.
- eval_metrics: Metrics | None
Evaluation metrics for when a val_dataloader is provided.
- current_state: dict
The current state of the model, including run_config, metrics and other necessary states.
- learning_rate: float
The current learning rate.
- total_steps: int
The total steps for the training process.
- epochs: int
The total epochs for the training process.
- current_iteration: int
The current iteration of training.
- best_iteration: int
The iteration with the best loss value.
- best_metrics: dict[str, float]
The lowest optim values encountered during training.
- optim_metric_name: str | None
The name of the optimization metric.
- optim_metric_direction: Optim | None
The optimization direction.
- abstract checkpoint(is_best: bool = False)[source]#
Abstract method to save a checkpoint of the model. Must be implemented by subclasses. Example implementation: see the checkpoint method in the ModelWrapper class.
- Parameters:
- is_best: bool
Indicates whether the current checkpoint is the best model so far, by default False.
- Raises:
- NotImplementedError
If this method is not implemented by the subclasses.
- abstract config_parser(run_config: RunConfig)[source]#
Abstract method to parse the provided configuration.
Must be implemented by subclasses. Example implementation: see the config_parser method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- Raises:
- NotImplementedError
If this method is not implemented by the subclasses.
- abstract create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) None[source]#
Abstract method to create and initialize the model. Must be implemented by subclasses. Example implementation: see the create_model method in the ModelWrapper class.
- Parameters:
- save_dict: dict[str, ty.Any] | None
A dictionary containing saved model data, such as weights, optimizer state, etc., to be loaded into the model, by default None.
- strict_load: bool
If True, the model will be loaded strictly, ensuring that the saved state matches the model’s structure exactly. If False, the model can be loaded with a partially matching state, by default True.
- Raises:
- NotImplementedError
If this method is not implemented by the subclasses.
- property current_epoch: int#
Calculates and returns the current epoch during training.
- Returns:
- int
The current epoch number.
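As a rough illustration of how an epoch index can be derived from an iteration counter, the sketch below assumes floor division of the current iteration by the epoch length. The helper name `current_epoch_from_iteration` is hypothetical, and the actual property may differ (for example in rounding or offset).

```python
def current_epoch_from_iteration(current_iteration: int, epoch_len: int) -> int:
    """Hypothetical helper: epoch index assuming floor division (an assumption)."""
    if epoch_len <= 0:
        raise RuntimeError("epoch_len must be positive")
    return current_iteration // epoch_len


# With 100 batches per epoch, iteration 250 falls in epoch index 2.
epoch = current_epoch_from_iteration(250, 100)
```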
- property epoch_len: int#
Returns the length of an epoch, which is the number of batches in the train_dataloader.
- Returns:
- int
The length of an epoch, represented as the number of batches in the train_dataloader.
- Raises:
- RuntimeError
If the train_dataloader is not defined or its length is 0.
- property eval_itr: int#
Calculate the interval between evaluations.
- Returns:
- int
The interval between evaluations.
- abstract evaluate(run_config: RunConfig)[source]#
Abstract method to evaluate the model. Must be implemented by subclasses. Example implementation: see the evaluate method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- Raises:
- NotImplementedError
If this method is not implemented by the subclasses.
- abstract evaluation_functions() dict[str, collections.abc.Callable] | None[source]#
Abstract method to create and return a dictionary of evaluation functions used during training and validation.
Must be implemented by subclasses. Example implementation: see the evaluation_functions method in the ModelWrapper class.
- Returns:
- dict[str, Callable] | None
A dictionary containing evaluation functions as values and their names as keys.
- Raises:
- NotImplementedError
If this method is not implemented by the subclasses.
- init_state(run_config: RunConfig, smoke_test: bool = False, debug: bool = False, resume: bool = False, remote_progress_bar: RemoteProgressBar | None = None, from_chkpt: str | Path | None = None, data_lock: Lock | None = None)[source]#
Initializes the state of the wrapper based on the provided configuration and parameters. Lazy initialization of the wrapper is necessary because the wrapper must be picklable, and initializing some objects (e.g. DataLoaders) up front would prevent that.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- smoke_test: bool
Whether to run as a smoke test, by default False.
- debug: bool
If True, disables logging and model directory creation, by default False.
- resume: bool
If True, tries to resume training from a checkpoint, by default False.
- remote_progress_bar: ty.Optional[RemoteProgressBar]
A remote progress bar that can be used to report metrics from the internal progress bar.
- from_chkpt: Path | str | None, optional
Path to the checkpoint to initialize the state from.
- data_lock: ty.Optional[Lock], optional
Use a Lock to avoid downloading data concurrently.
- Raises:
- RuntimeError
If the state is already initialized and the smoke_test, debug and resume flags are all False.
- abstract load_checkpoint(save_dict: dict[str, Any], model_only: bool = False)[source]#
Abstract method to load the model and its state from a given save dictionary.
Must be implemented by subclasses. Example implementation: see the load_checkpoint method in the ModelWrapper class.
- Parameters:
- save_dict: dict[str, ty.Any]
A dictionary containing the saved model state and other necessary information.
- model_only: bool
If True, only the model’s weights will be loaded, ignoring other state information, by default False.
- Raises:
- NotImplementedError
If this method is not implemented by the subclasses.
- property log_itr: int#
Calculate the interval between logging steps.
- Returns:
- int
The interval between logging steps.
- abstract make_dataloaders(run_config: RunConfig)[source]#
Abstract method to create dataloaders for the training, validation, and testing datasets.
This method should define the process of loading the data and creating dataloaders for the training, validation, and testing datasets based on the provided run_config.
Must be implemented by subclasses. Example implementation: see the make_dataloaders method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- Raises:
- NotImplementedError
If this method is not implemented by the subclasses.
- abstract save_dict() dict[str, Any] | None[source]#
Abstract method to create and return a save dictionary containing the model’s state and other necessary information.
Must be implemented by subclasses. Example implementation: see the save_dict method in the ModelWrapper class.
- Returns:
- dict[str, ty.Any] | None
A dictionary containing the saved model state and other necessary information.
- Raises:
- NotImplementedError
If this method is not implemented by the subclasses.
- abstract train(run_config: RunConfig, smoke_test: bool = False)[source]#
Abstract method to train the model. Must be implemented by subclasses. Example implementation: see the train method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- smoke_test: bool
Whether to run as a smoke test, by default False.
- Raises:
- NotImplementedError
If this method is not implemented by the subclasses.
- property train_stats: OrderedDict#
Returns an ordered dictionary containing the current training statistics.
- Returns:
- OrderedDict
An ordered dictionary with the following keys and values:
learning_rate: The current learning rate.
total_steps: The total steps for the training process.
epochs: The number of epochs for training.
current_epoch: The current epoch during training.
current_iteration: The current iteration during training.
best_iteration: The iteration with the best loss value so far.
best_*: The best (lowest) optim metric value achieved during training, for example best_val_loss
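To make the shape of this dictionary concrete, here is a minimal sketch that assembles the keys listed above in order. The helper name `build_train_stats` and all values are hypothetical; it only assumes that each best metric is stored under a `best_` prefixed key, as the `best_val_loss` example suggests.

```python
from collections import OrderedDict


def build_train_stats(
    learning_rate: float,
    total_steps: int,
    epochs: int,
    current_epoch: int,
    current_iteration: int,
    best_iteration: int,
    best_metrics: dict,
) -> OrderedDict:
    """Hypothetical helper that assembles the keys listed above, in order."""
    stats = OrderedDict(
        learning_rate=learning_rate,
        total_steps=total_steps,
        epochs=epochs,
        current_epoch=current_epoch,
        current_iteration=current_iteration,
        best_iteration=best_iteration,
    )
    # Assumption: each best metric is stored under a "best_" prefixed key.
    for name, value in best_metrics.items():
        stats[f"best_{name}"] = value
    return stats


stats = build_train_stats(1e-3, 1000, 10, 2, 250, 200, {"val_loss": 0.42})
```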
- property uid: str#
Returns a unique identifier (UID) for the current run configuration.
- Returns:
- str
A string representing the unique identifier of the current run configuration.
ablator.main.model.wrapper module#
- class ablator.main.model.wrapper.ModelWrapper(model_class: type[torch.nn.modules.module.Module])[source]#
Bases:
ModelBase
A wrapper around model_class that removes training boilerplate code. Once make_dataloader_train is overridden to provide a training dataset, you can pass a ModelWrapper object to the trainers (ProtoTrainer or ParallelTrainer) along with a running configuration (RunConfig or ParallelConfig) to launch the experiment.
The wrapper lets you customize the training process by overriding various functions, which makes it adaptable to different training paradigms. Several customization use cases are shown below.
- Parameters:
- model_class: type[nn.Module]
The model class to wrap.
- Attributes:
- model_class: torch.nn.Module
The model class to wrap.
- model: torch.nn.Module
The model created from the model class or checkpoint.
- optimizer: Optimizer
The optimizer created from the optimizer config or checkpoint.
- scaler: GradScaler
The scaler created from the scaler config or checkpoint.
- scheduler: Scheduler
The scheduler created from the scheduler config or checkpoint.
- apply_loss(model: Module, loss: Tensor | None, optimizer: Optimizer, scaler: GradScaler) float | None[source]#
Calculate the loss and apply the gradients, call optimizer.step() and scheduler.step().
- Parameters:
- model: nn.Module
The model to apply the loss to.
- loss: torch.Tensor | None
The loss to apply.
- optimizer: Optimizer
The optimizer to step.
- scaler: torch.cuda.amp.GradScaler
The scaler to use during mixed precision training.
- Returns:
- float | None
The loss value.
- checkpoint(is_best: bool = False)[source]#
Saves a checkpoint of the model. The class name of the model is used as the filename.
- Parameters:
- is_best: bool
Whether this is the best model so far, by default False.
- config_parser(run_config: RunConfig) RunConfig[source]#
You can override this function to initialize Derived properties that are not decided until the experiment is launched.
- Parameters:
- run_config: RunConfig
The run config for the experiment.
- Returns:
- RunConfig
Examples
For example, with GPT2 models we sometimes need to resize the model’s vocabulary to match the tokenizer’s vocabulary size, which depends on the tokenizer used. This is decided only after the experiment has been launched, because you might want to run an ablation study on different tokenizers:
>>> class MyLMWrapper(ModelWrapper):
...     def config_parser(self, run_config: RunConfig):
...         run_config.model_config.resize_token_embedding = len(self.train_dataloader.dataset.tokenizer)
...         return run_config
- create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True)[source]#
Creates the model, optimizer, scheduler, and scaler from a saved checkpoint dictionary or from configuration objects. You can override this function and the save_dict() function to customize the saving and loading of the model, optimizer, and scheduler to your needs. An example of this is shown in the Saving and loading multi-module models tutorial.
- Parameters:
- save_dict: dict[str, ty.Any] | None
The saved checkpoint dictionary to load from, by default None.
- strict_load: bool
Whether to throw an error for mismatched keys, by default True.
- create_optimizer(model: Module, optimizer_config: OptimizerConfig | None = None, optimizer_state: dict[str, Any] | None = None) Optimizer[source]#
Creates the optimizer from the saved state or from config.
- Parameters:
- model: nn.Module
The model to create the optimizer for.
- optimizer_config: OptimizerConfig | None
The optimizer config to create the optimizer from, by default None.
- optimizer_state: dict[str, ty.Any] | None
The optimizer state to load the optimizer from, by default None.
- Returns:
- Optimizer
The optimizer.
- create_scaler(scaler_state: dict[str, Any] | None = None) GradScaler[source]#
Creates the scaler from the saved state or from config.
- Parameters:
- scaler_state: ty.Optional[dict[str, ty.Any]]
The scaler state to load the scaler from, by default None.
- Returns:
- GradScaler
The scaler.
- create_scheduler(model: Module, optimizer: Optimizer, scheduler_config: SchedulerConfig | None = None, scheduler_state: dict[str, Any] | None = None) _LRScheduler | ReduceLROnPlateau | Any | None[source]#
Creates the scheduler from the saved state or from config.
- Parameters:
- model: nn.Module
The model to create the scheduler for.
- optimizer: Optimizer
The optimizer to create the scheduler for.
- scheduler_config: SchedulerConfig | None
The scheduler config to create the scheduler from, by default None.
- scheduler_state: dict[str, ty.Any] | None
The scheduler state to load the scheduler from, by default None.
- Returns:
- ty.Optional[Scheduler]
The scheduler.
- final eval(smoke_test: bool = False)[source]#
Evaluates the model, then updates the scheduler and saves a checkpoint if the current iteration is an evaluation step. It also checks whether early stopping should be triggered (see the Model Configuration module for more details).
- Parameters:
- smoke_test: bool
Whether to run as a smoke test, by default False.
- Raises:
- LossDivergedError
If the loss diverges during evaluation.
- TrainPlateauError
If the improvement in loss slows down dramatically.
- final evaluate(run_config: RunConfig, chkpt: str | Path | None = None) dict[str, dict[str, Any]][source]#
Evaluate the model after training on the test and validation sets.
- Parameters:
- run_config: RunConfig
The run config to use for evaluation.
- chkpt: str | Path | None
Path to the checkpoint to evaluate. If None, the latest checkpoint is evaluated, by default None.
- Returns:
- dict[str, dict[str, ty.Any]]
Metrics
- evaluation_functions() dict[str, collections.abc.Callable] | None[source]#
You can override this function to return the callables that will be used to evaluate experiment metrics.
- Returns:
- dict[str, Callable] | None
The evaluation functions to use.
Examples
Define the callables:
>>> from sklearn.metrics import accuracy_score, f1_score
>>> def my_accuracy(y_true, y_pred):
...     return accuracy_score(y_true.flatten(), y_pred.flatten())
>>> def my_f1_score(y_true, y_pred):
...     return f1_score(y_true.flatten(), y_pred.flatten(), average='weighted')
Return the callables in evaluation_functions:
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
...
...     def evaluation_functions(self):
...         return {
...             "accuracy": my_accuracy,
...             "f1": my_f1_score
...         }
Note that the callable’s parameter names must match the keys of the model’s forward output. In our example, y_true and y_pred must be returned by the model’s forward method to match the y_true and y_pred parameters of the my_accuracy and my_f1_score functions:
>>> class MyModel(nn.Module):
...     def __init__(self, config: CustomModelConfig) -> None:
...         super().__init__()
...         self.model = FashionMNISTModel(config)
...         self.loss = nn.CrossEntropyLoss()
...
...     def forward(self, x, labels=None):
...         out = self.model(x)
...         loss = None
...         if labels is not None:
...             loss = self.loss(out, labels)
...         out = out.argmax(dim=-1)
...         return {"y_pred": out, "y_true": labels}, loss
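The name-matching requirement can be illustrated with a small, library-independent sketch. It is not the wrapper's actual dispatch code: the helper `call_evaluation_function` is hypothetical, plain lists stand in for tensors, and the only assumption is that each callable is invoked with the forward-output entries whose keys equal its parameter names.

```python
import inspect
from typing import Any, Callable


def call_evaluation_function(fn: Callable[..., float], outputs: dict) -> Any:
    """Call ``fn`` with the entries of ``outputs`` whose keys match its parameter names."""
    param_names = list(inspect.signature(fn).parameters)
    missing = [name for name in param_names if name not in outputs]
    if missing:
        raise ValueError(f"Model outputs are missing keys: {missing}")
    return fn(**{name: outputs[name] for name in param_names})


def my_accuracy(y_true, y_pred):
    # Fraction of positions where the prediction equals the label.
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Stand-in for the dictionary returned by the model's forward method.
outputs = {"y_pred": [1, 0, 1, 1], "y_true": [1, 0, 0, 1]}
accuracy = call_evaluation_function(my_accuracy, outputs)
```

If the forward output lacked a `y_true` key, this sketch would raise a `ValueError`, which is the kind of mismatch the note above warns about.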
- load_checkpoint(save_dict: dict[str, Any], model_only: bool = False)[source]#
Loads the checkpoint from the save dict.
- Parameters:
- save_dict: dict[str, ty.Any]
The save dict to load the checkpoint from.
- model_only: bool
Whether to load only the model or include scheduler, optimizer and scaler, by default False.
Notes
This method is the implementation of the abstract method in the base class.
- log()[source]#
Logs if the current iteration is a logging step. It also evaluates training metrics for logging.
- log_step()[source]#
A single step for logging.
Notes
This method updates the logger with the current metrics and logs a status message.
- make_dataloader_test(run_config: RunConfig) DataLoader | None[source]#
Function to make the test dataloader. You can override this function to return the test dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The test dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...
...     def make_dataloader_test(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             test_dataset,
...             batch_size=32,
...             shuffle=False
...         )
- abstract make_dataloader_train(run_config: RunConfig) DataLoader[source]#
Function to make the training dataloader. You must override this function to return the training dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader
The training dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
- make_dataloader_val(run_config: RunConfig) DataLoader | None[source]#
Function to make the validation dataloader. You can override this function to return the validation dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The validation dataloader.
Examples
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             train_dataset,
...             batch_size=32,
...             shuffle=True
...         )
...
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(
...             val_dataset,
...             batch_size=32,
...             shuffle=False
...         )
- make_dataloaders(run_config: RunConfig) None[source]#
Creates the dataloaders. This is done post-initialization because otherwise the dataloaders would be pickled with the object when running distributed.
- Parameters:
- run_config: RunConfig
The run config for the experiment.
- property metrics: dict[str, float]#
The metrics of the current training state. If eval_metrics are defined (e.g. a validation dataloader was provided), the training and evaluation metrics are combined, with a train_ or val_ prefix prepended to each name. The current training statistics are also included.
- Returns:
- dict[str, float]
A dictionary mapping each metric name to its value.
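The prefixing behaviour described for this property can be sketched as follows. This is only an illustration under stated assumptions: the function `combine_metrics` is hypothetical, plain dictionaries stand in for the Metrics objects, and the unprefixed fallback used when no evaluation metrics exist is a guess rather than documented behaviour.

```python
from typing import Optional


def combine_metrics(
    train_metrics: dict,
    eval_metrics: Optional[dict],
    train_stats: dict,
) -> dict:
    """Hypothetical merge of training stats, training metrics and evaluation metrics."""
    combined = dict(train_stats)
    if eval_metrics is None:
        # Assumption: without evaluation metrics, training metrics keep their names.
        combined.update(train_metrics)
        return combined
    combined.update({f"train_{k}": v for k, v in train_metrics.items()})
    combined.update({f"val_{k}": v for k, v in eval_metrics.items()})
    return combined


merged = combine_metrics({"loss": 0.5}, {"loss": 0.7}, {"learning_rate": 0.001})
```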
- property model_config: ModelConfig#
- model_step(model: Module, batch: Iterable) tuple[dict[str, torch.Tensor] | None, torch.Tensor | None][source]#
A single inference step for the model.
- Parameters:
- model: nn.Module
The model to train.
- batch: Iterable
The batch of input data to pass through the model; it could be a list, dict or a single tensor.
- Returns:
- tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]
The output of the model, containing the current predictions and the loss of the model.
- final optim_step(optimizer: Optimizer, scaler: GradScaler, model: Module, loss: Tensor | None)[source]#
- save_dict() dict[str, Any][source]#
Saves the current state of the trainer, including model parameters and the current states of the optimizer, scaler, and scheduler. You can override this function and create_model() to customize the saving and loading of the model, optimizer, and scheduler to your needs. An example of this is shown in the Saving and loading multi-module models tutorial.
- Returns:
- dict[str, ty.Any]
The current state of the trainer, with the following keys:
model: the model’s state.
optimizer: the optimizer’s state.
scheduler: the scheduler’s state.
scaler: the GradScaler’s state.
- final scheduler_step(*args, is_val_step=False)[source]#
A single scheduler step. This function accounts for when the scheduler is supposed to take a step. It reads step_when as a property from the scheduler or derives it from the train configuration. step_when is expected to be in {“train”, “epoch”, “val”}. When the scheduler is a validation scheduler, it expects some metrics to be passed as arguments to the scheduler step function, e.g. wrapper.scheduler_step(0.001).
When no args are provided and the scheduler is a validation scheduler, this function is a no-op.
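The decision described above can be sketched in plain Python. Everything here is an assumption made for illustration: the function `maybe_step_scheduler` is hypothetical, "train" is taken to mean a step on every training iteration, and "epoch" a step whenever the iteration count is a multiple of the epoch length. Only the no-op rule for a validation scheduler without arguments comes from the description.

```python
from typing import Any, Callable


def maybe_step_scheduler(
    step_fn: Callable[..., Any],
    step_when: str,
    current_iteration: int,
    epoch_len: int,
    *args: float,
    is_val_step: bool = False,
) -> bool:
    """Return True if ``step_fn`` was called. All branching rules are assumptions."""
    if step_when not in {"train", "epoch", "val"}:
        raise ValueError(f"Unexpected step_when: {step_when!r}")
    if step_when == "val":
        # A validation scheduler needs a metric; without one this is a no-op.
        if is_val_step and args:
            step_fn(*args)
            return True
        return False
    if is_val_step:
        return False
    if step_when == "epoch" and current_iteration % epoch_len != 0:
        return False
    step_fn()
    return True


calls = []


def record(*step_args):
    # Stand-in for scheduler.step that records the arguments it received.
    calls.append(step_args)


stepped_train = maybe_step_scheduler(record, "train", 7, 100)
stepped_epoch = maybe_step_scheduler(record, "epoch", 7, 100)
stepped_val = maybe_step_scheduler(record, "val", 7, 100, 0.001, is_val_step=True)
noop_val = maybe_step_scheduler(record, "val", 7, 100, is_val_step=True)
```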
- status_message() str[source]#
Returns a string generated from the dictionary of current metrics, including all the static metrics and moving average metrics.
- Returns:
- str
The status message.
- to_device(data: Iterable, device: str | None = None) Iterable[source]#
Moves the data to the specified device.
- Parameters:
- data: Iterable
The data to move to the device.
- device: ty.Optional[str]
The device to move the data to. If None, the device specified in the config is used.
- Returns:
- Iterable
The data on the device.
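Since a batch may be a list, a dict or a single tensor, moving it to a device usually means walking the container. The sketch below shows one way this could work; it is not the library's implementation. `move_to_device` is a hypothetical helper, and `FakeTensor` is a stand-in for `torch.Tensor` so that the example runs without PyTorch.

```python
from typing import Any


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


def move_to_device(data: Any, device: str) -> Any:
    """Hypothetical recursive helper; not the library's implementation."""
    if hasattr(data, "to"):
        return data.to(device)
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    # Anything else (ints, strings, None) is returned unchanged.
    return data


batch = {"x": FakeTensor(), "labels": [FakeTensor(), 3]}
moved = move_to_device(batch, "cuda:0")
```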
- property total_steps: int#
The total number of steps for training.
- Returns:
- int
The total number of steps, i.e. the epoch length multiplied by the number of epochs.
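A quick worked example of this arithmetic, with made-up numbers (60,000 training samples, batch size 32, 10 epochs) and a hypothetical helper name:

```python
import math


def compute_total_steps(epoch_len: int, epochs: int) -> int:
    """Total training steps: batches per epoch times number of epochs."""
    return epoch_len * epochs


# 60,000 samples at batch size 32 gives 1875 batches per epoch.
epoch_len = math.ceil(60_000 / 32)
steps = compute_total_steps(epoch_len, epochs=10)
```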
- final train(run_config: RunConfig | None = None, smoke_test: bool = False, debug: bool = False, resume: bool = False) dict[str, float][source]#
Initializes the state and trains the model. On a keyboard interrupt, saves a checkpoint.
- Parameters:
- run_config: RunConfig | None
The run config to use for training, by default None.
- smoke_test: bool
Whether to run a smoke test, by default False.
- debug: bool
Whether to run in debug mode, by default False.
- resume: bool
Whether to resume training the model from existing checkpoints and existing experiment state, by default False.
- Returns:
- dict[str, float]
The metrics from the training.
- Raises:
- ValueError
If the state is not initialized and no run_config is provided, or if a run_config is provided but the state is already initialized.
- property train_config: TrainConfig#
- train_loop(smoke_test: bool = False) dict[str, float][source]#
Trains the model over many steps, evaluating the model and logging the metrics at each iteration. The metrics include static metrics such as the learning rate, along with validation and training metrics such as loss and mean.
- Parameters:
- smoke_test: bool
Whether to run a smoke test.
- Returns:
- dict[str, float]
The metrics after completion of training.
- Raises:
- ValueError
If model outputs are not stored in the correct format.
- LossDivergedError
If the loss diverged.
- final train_step(batch: Iterable) tuple[dict[str, torch.Tensor] | None, dict[str, Any]][source]#
A single step for training. It also updates the learning rate with the scheduler.
- Parameters:
- batch: Iterable
The batch of input data to pass through the model; it could be a list, dict or a single tensor.
- Returns:
- tuple[dict[str, torch.Tensor] | None, dict[str, ty.Any]]
- outputs: dict[str, torch.Tensor] | None
The output of the model.
- train_metrics: dict[str, ty.Any]
The training metrics.
- update_status() None[source]#
Updates the metrics with the current training stats, then sets all metrics (static and moving average) as the description of the tqdm progress bar.
- validation_loop(model: Module, dataloader: DataLoader, metrics: Metrics, subsample: float = 1.0, smoke_test: bool = False) dict[str, float][source]#
Validates the model on the data in dataloader and stores the results in metrics.
- Parameters:
- model: nn.Module
The model to validate.
- dataloader: DataLoader
The dataloader to use for validation.
- metrics: Metrics
The metrics to use for validation.
- subsample: float
The fraction of the dataloader to use for validation, by default 1.0.
- smoke_test: bool
Whether to execute this function as a smoke test. If True, only one iteration will be performed, which is useful for quickly checking if the code runs without errors, by default False.
- Returns:
- dict[str, float]
The metrics from the validation.
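The interaction between subsample and smoke_test can be sketched without PyTorch. The generator below is a hypothetical illustration, not the library's loop: it assumes the fraction is rounded up and that the first batches are the ones kept, and it uses a plain list as a stand-in for a dataloader.

```python
import math
from typing import Iterable, Iterator


def subsampled_batches(
    dataloader: Iterable, n_batches: int, subsample: float = 1.0, smoke_test: bool = False
) -> Iterator:
    """Yield at most ceil(n_batches * subsample) batches, or one batch for a smoke test."""
    if not 0.0 < subsample <= 1.0:
        raise ValueError("subsample must be in (0, 1]")
    limit = 1 if smoke_test else math.ceil(n_batches * subsample)
    for i, batch in enumerate(dataloader):
        if i >= limit:
            break
        yield batch


batches = list(range(10))  # stand-in for a dataloader with 10 batches
half = list(subsampled_batches(batches, len(batches), subsample=0.5))
smoke = list(subsampled_batches(batches, len(batches), smoke_test=True))
```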