API Reference
This is the class and function reference for ablator. For further details, refer to the tutorials: the raw class and function specifications alone may not give full guidance on how to use them.
ablator.analysis: Analysis module
Plot module
Class for plotting experiment results.
This class prepares the results associated with each categorical attribute under study (e.g., grouping metric results by attribute).
Class for constructing violin plots.
Base class for numerical plots.
Class for generating linear plots.
Returns a dictionary mapping input attribute names to output attribute names, with optional remapping based on
Analysis Main classes
A class that stores and processes the attributes, metrics, and other data used to plot the experiment results.
Analysis Results classes
Class for processing experiment results.
Read the results of an experiment and return them as a pandas
ablator.config: Config module
Base Config classes
This class is the building block for all configuration objects within ablator.
Decorator for
Prototype Config classes
A base class for model configuration.
The base run configuration that defines the settings of an experiment (experiment main directory, number of checkpoints to keep, hardware device to use, etc.).
Training configuration that defines the training settings, e.g., batch size, number of epochs, the optimizer to use, etc.
Type of optimization direction.
Config Type classes
HPO Config classes
Type of search space.
Search space configuration, required in
Subconfiguration for a
Parallel Config classes
Parallel training configuration, extending from
Type of search algorithm.
Config Utils functions
Calculates the MD5 hash of one or more dictionaries.
Flattens a nested dictionary, expanding lists and tuples if specified.
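Both config utilities can be sketched with the standard library alone. This is a minimal illustration, assuming JSON with sorted keys as the canonical form for hashing and `.` as the key separator when flattening; `dict_md5` and `flatten_dict` are hypothetical names, not ablator's functions.

```python
import hashlib
import json
from typing import Any


def dict_md5(*dicts: dict) -> str:
    """Hypothetical helper: one MD5 digest over one or more dictionaries."""
    md5 = hashlib.md5()
    for d in dicts:
        # sort_keys=True makes the digest independent of key insertion order;
        # default=str serializes non-JSON values (e.g. Path) via str().
        md5.update(json.dumps(d, sort_keys=True, default=str).encode("utf-8"))
    return md5.hexdigest()


def flatten_dict(
    d: dict, expand_list: bool = True, sep: str = ".", prefix: str = ""
) -> dict:
    """Hypothetical helper: flatten nested dicts into `sep`-joined keys.

    When expand_list is True, list and tuple items are flattened using their
    index as the key; otherwise they are kept as leaf values.
    """
    out: dict[str, Any] = {}
    for key, value in d.items():
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            out.update(flatten_dict(value, expand_list, sep, name))
        elif expand_list and isinstance(value, (list, tuple)):
            indexed = {str(i): v for i, v in enumerate(value)}
            out.update(flatten_dict(indexed, expand_list, sep, name))
        else:
            out[name] = value
    return out
```

For instance, `flatten_dict({"train": {"lr": 0.1, "sizes": [32, 64]}})` returns `{"train.lr": 0.1, "train.sizes.0": 32, "train.sizes.1": 64}`. Sorting keys before hashing means two configs that differ only in key order get the same digest.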
ablator.main: Main module
HPO Sampler module
GridSampler: expands the grid space into evenly spaced intervals.
OptunaSampler: serves as an interface for Optuna-based samplers.
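To illustrate what expanding a grid space into evenly spaced intervals means, here is a minimal sketch. It assumes numeric ranges are given as `(low, high)` tuples and categorical choices as lists; `expand_range` and `grid` are hypothetical names, and ablator's GridSampler may discretize or order configurations differently.

```python
import itertools
from typing import Any


def expand_range(low: float, high: float, n_bins: int) -> list:
    """Return n_bins evenly spaced values from low to high, inclusive."""
    if n_bins < 2:
        return [low]
    step = (high - low) / (n_bins - 1)
    return [low + i * step for i in range(n_bins)]


def grid(space: dict, n_bins: int = 3) -> list:
    """Hypothetical grid expansion into every combination of values.

    A (low, high) tuple is treated as a numeric range and discretized;
    a list is treated as categorical choices and used as-is.
    """
    axes: dict[str, list[Any]] = {}
    for name, spec in space.items():
        if isinstance(spec, tuple):
            axes[name] = expand_range(spec[0], spec[1], n_bins)
        else:
            axes[name] = list(spec)
    names = list(axes)
    # Cartesian product: the first parameter varies slowest.
    return [dict(zip(names, combo)) for combo in itertools.product(*axes.values())]
```

With `grid({"lr": (0.0, 1.0), "opt": ["sgd", "adam"]}, n_bins=3)`, the learning rate takes the values 0.0, 0.5 and 1.0, giving six configurations in total. The grid size grows multiplicatively with each parameter, which is why grid search is usually reserved for small spaces.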
Main Model module
Base class that removes training boilerplate code, with extensible support for multiple use cases.
A wrapper around
Experiment State module
Initializes the ExperimentState.
An enumeration of the possible states of a trial, including additional pruned states.
Prototype Trainer classes
Manages resources for prototyping.
Multi-process Trainer classes
A class that uses Ray to parallelize multiple training processes of models with different configurations.
ablator.modules: Extra modules
Metrics module
Stores and manages predictions and calculates metrics using custom evaluation functions.
Base class for manipulating (storing, getting, resetting) batches of values.
This class is used to store moving average metrics.
A class for storing prediction scores.
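The description above does not say whether the moving average is windowed or exponential. As a minimal sketch, assuming a fixed-size window over the most recent values (the class `MovingAverage` and its methods are hypothetical, not ablator's API):

```python
from collections import deque


class MovingAverage:
    """Hypothetical windowed moving average over the last `window` values."""

    def __init__(self, window: int = 100) -> None:
        if window < 1:
            raise ValueError("window must be >= 1")
        self._values: deque = deque(maxlen=window)

    def append(self, value: float) -> None:
        # deque(maxlen=...) drops the oldest value automatically once full.
        self._values.append(float(value))

    @property
    def value(self) -> float:
        # NaN signals "no data yet" rather than a misleading 0.0.
        if not self._values:
            return float("nan")
        return sum(self._values) / len(self._values)

    def reset(self) -> None:
        self._values.clear()
```

After appending 1, 2, 3 and 4 to `MovingAverage(window=3)`, `value` is 3.0, the mean of the last three entries. A bounded `deque` keeps memory constant no matter how many training steps are logged.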
Storage module
Configuration for remote storage.
Run a command and wait for it to finish.
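Remote-storage syncing commonly shells out to a command-line tool and blocks until it exits. A minimal sketch of such a blocking helper, using the standard `subprocess` module, is shown below; `run_cmd` is a hypothetical name, not ablator's function.

```python
import subprocess
import sys
from typing import Optional


def run_cmd(cmd: list, timeout: Optional[float] = None) -> str:
    """Hypothetical helper: run `cmd`, block until it exits, return its stdout.

    check=True raises subprocess.CalledProcessError on a non-zero exit code,
    and `timeout` (seconds) raises subprocess.TimeoutExpired if exceeded.
    """
    result = subprocess.run(
        cmd, capture_output=True, text=True, check=True, timeout=timeout
    )
    return result.stdout


if __name__ == "__main__":
    # Uses the current interpreter so the example runs on any platform.
    print(run_cmd([sys.executable, "-c", "print('done')"]), end="")
```

Passing the command as a list (rather than a shell string) avoids shell-quoting issues, and `check=True` makes a failed transfer surface as an exception instead of passing silently.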
Optimizer classes
A base class for optimizer arguments; it defines the learning rate lr.
Configuration for an optimizer, including the optimizer name and its arguments (these arguments are specific to the type of optimizer, such as SGD, Adam, or AdamW).
Configuration for an SGD optimizer.
Configuration for an AdamW optimizer.
Configuration for an
Get model parameters to be optimized.
Scheduler classes
Abstract base class for defining the arguments used to initialize a learning rate scheduler.
A class that defines the configuration for a learning rate scheduler.
Configuration class for the StepLR scheduler.
Configuration class for the OneCycleLR scheduler.
Configuration class for the ReduceLROnPlateau scheduler.
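Assuming the StepLR configuration maps onto PyTorch's `torch.optim.lr_scheduler.StepLR`, the learning rate after $t$ epochs decays by a factor $\gamma$ (`gamma`) every $s$ (`step_size`) epochs, starting from the initial rate $\eta_0$:

```latex
\eta_t = \eta_0 \cdot \gamma^{\left\lfloor t / s \right\rfloor}
```

For example, with $\eta_0 = 0.1$, $\gamma = 0.1$ and $s = 30$, the learning rate is 0.1 for epochs 0 to 29, 0.01 for epochs 30 to 59, and so on.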
ablator.utils: Utilities
Base utilities
Convert elements of the input iterable to NumPy arrays if they are torch.Tensor objects.
Move torch.Tensor elements to the specified device.
Applies a given function
Set the random seed.
Get the learning rate from an optimizer.
Check if the debugger is currently active.
Get a list of all checkpoint files in a directory, sorted from the latest to the earliest.
Parse a device string, an integer, or a list of device strings or integers.
is_oom_exception checks whether the exception is caused by a CUDA out-of-memory error.
Format a number so that it is no wider than width, converting it to scientific notation when it would exceed width, either because of its informative decimal places or because of its magnitude.
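The accepted input forms for device parsing are not spelled out above. As a sketch, assuming PyTorch-style device strings (`cpu`, `cuda`, `cuda:N`) and bare integers meaning CUDA device indices, a normalizer might look like this; `parse_device` and the accepted forms are hypothetical, not ablator's exact behavior.

```python
from typing import List, Union

DeviceSpec = Union[str, int, List[Union[str, int]]]


def parse_device(device: DeviceSpec) -> Union[str, List[str]]:
    """Hypothetical normalizer for device specifications.

    Accepted forms (an assumption for illustration):
      int n    -> "cuda:n"
      "cpu"    -> "cpu"
      "cuda"   -> "cuda" (the default CUDA device)
      "cuda:n" -> "cuda:n"
      "n"      -> "cuda:n"
      list     -> a list of normalized devices
    """
    if isinstance(device, list):
        return [parse_device(d) for d in device]
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(device, int) and not isinstance(device, bool):
        if device < 0:
            raise ValueError(f"Negative device index: {device}")
        return f"cuda:{device}"
    if isinstance(device, str):
        dev = device.strip().lower()
        if dev in ("cpu", "cuda"):
            return dev
        if dev.startswith("cuda:") and dev[len("cuda:"):].isdigit():
            return dev
        if dev.isdigit():
            return f"cuda:{dev}"
    raise ValueError(f"Unrecognized device: {device!r}")
```

Normalizing every input to one canonical string early means the rest of the training code only has to handle a single representation.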
File utilities
Remove all but the n latest checkpoints from the given directory.
Converts the input value to a JSON-compatible format.
Convert a dictionary into a JSON string.
Convert a JSON string into a dictionary.
Create subdirectories under the given parent directory.
Set a value in a nested dictionary.
Save a checkpoint of the given state.
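Two of the file utilities can be sketched in a few lines. This minimal version assumes checkpoints are `*.pt` files and that "latest" means most recently modified; `sorted_checkpoints`, `keep_n_latest`, and `set_nested` are hypothetical names, not ablator's API.

```python
from pathlib import Path


def sorted_checkpoints(ckpt_dir, pattern: str = "*.pt") -> list:
    """Hypothetical: checkpoint files, newest first by modification time."""
    files = [p for p in Path(ckpt_dir).glob(pattern) if p.is_file()]
    # The file name breaks ties between files with identical mtimes.
    return sorted(files, key=lambda p: (p.stat().st_mtime, p.name), reverse=True)


def keep_n_latest(ckpt_dir, n: int, pattern: str = "*.pt") -> list:
    """Hypothetical: delete all but the n newest checkpoints; return deleted paths."""
    if n < 0:
        raise ValueError("n must be >= 0")
    stale = sorted_checkpoints(ckpt_dir, pattern)[n:]
    for path in stale:
        path.unlink()
    return stale


def set_nested(d: dict, keys: list, value) -> dict:
    """Hypothetical: set d[k0][k1]...[kN] = value, creating dicts as needed."""
    if not keys:
        raise ValueError("keys must be non-empty")
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value
    return d
```

For example, `set_nested({}, ["train", "optim", "lr"], 0.01)` returns `{"train": {"optim": {"lr": 0.01}}}`. Sorting by modification time rather than by name avoids depending on a particular checkpoint naming scheme.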