Prototype Trainer
- class ablator.main.proto.ProtoTrainer(wrapper: ModelWrapper, run_config: RunConfig)
Manages resources for prototyping. This trainer runs an experiment on a single prototype model, so it performs neither an ablation study nor hyperparameter optimization (HPO).
- Parameters:
- wrapper : ModelWrapper
The main model wrapper.
- run_config : RunConfig
Running configuration for the model.
- Raises:
- RuntimeError
If the experiment directory is not defined in the running configuration.
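The `RuntimeError` above can be anticipated with a pre-flight check before the trainer is constructed. The sketch below is not ablator's implementation: `require_experiment_dir` is a hypothetical helper, and `SimpleNamespace` stands in for a `RunConfig` so the snippet runs without the library. It assumes only that the configuration exposes an `experiment_dir` attribute.

```python
from pathlib import Path
from types import SimpleNamespace


def require_experiment_dir(run_config) -> Path:
    """Return the experiment directory as a Path, or raise if it is unset.

    Hypothetical helper: mirrors the documented failure mode, not ablator's code.
    """
    experiment_dir = getattr(run_config, "experiment_dir", None)
    if experiment_dir is None:
        raise RuntimeError("experiment_dir must be set in the run configuration.")
    return Path(experiment_dir)


if __name__ == "__main__":
    # Stand-in object; a real RunConfig would be passed here instead.
    stand_in_config = SimpleNamespace(experiment_dir="/tmp/experiments")
    print(require_experiment_dir(stand_in_config))
```

Running the check early surfaces a missing directory before any model or dataloader is built.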
Examples
Below is a complete workflow for launching a prototype experiment with ProtoTrainer, from defining the config to launching the experiment.
Define training config:
>>> my_optimizer_config = OptimizerConfig("sgd", {"lr": 0.5, "weight_decay": 0.5})
>>> my_scheduler_config = SchedulerConfig("step", arguments={"step_size": 1, "gamma": 0.99})
>>> train_config = TrainConfig(
...     dataset="[Dataset Name]",
...     batch_size=32,
...     epochs=10,
...     optimizer_config=my_optimizer_config,
...     scheduler_config=my_scheduler_config,
... )
Define model config. Here we use the default one, with no custom hyperparameters. To run an ablation study or HPO over the model's hyperparameters in a parallel experiment, customize the model config and use ParallelTrainer and ParallelConfig instead of ProtoTrainer and RunConfig:
>>> model_config = ModelConfig()
Define run config:
>>> run_config = RunConfig(
...     train_config=train_config,
...     model_config=model_config,
...     metrics_n_batches=800,
...     experiment_dir="/tmp/experiments",
...     device="cpu",
...     amp=False,
...     random_seed=42,
... )
Create model wrapper:
>>> import torch
>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: RunConfig):
...         return torch.utils.data.DataLoader(<train_dataset>, batch_size=32, shuffle=True)
...
...     def make_dataloader_val(self, run_config: RunConfig):
...         return torch.utils.data.DataLoader(<val_dataset>, batch_size=32, shuffle=False)
With the configurations and the model wrapper in place, initialize and launch the prototype trainer. Launching the experiment requires a working directory that points to a git repository, which is used to keep track of code differences:
>>> import os
>>> wrapper = MyModelWrapper(
...     model_class=<your_ModelModule_class>,
... )
>>> ablator = ProtoTrainer(
...     wrapper=wrapper,
...     run_config=run_config,
... )
>>> # Assumes the current directory is tracked by git.
>>> metrics = ablator.launch(working_directory=os.getcwd())
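Because `launch` expects `working_directory` to point to a git repository, it can help to verify that before starting a run. The following is a minimal sketch using only the standard library. `find_git_root` is a hypothetical helper, not part of ablator; it only looks for a `.git` entry and does not reproduce whatever check ablator itself performs.

```python
import os
from pathlib import Path
from typing import Optional


def find_git_root(path: str) -> Optional[Path]:
    """Return the nearest directory at or above `path` that holds a `.git` entry.

    Hypothetical helper. Returns None when no such directory is found.
    """
    start = Path(path).resolve()
    for candidate in (start, *start.parents):
        # `.git` is a directory in a regular clone and a file in worktrees or submodules.
        if (candidate / ".git").exists():
            return candidate
    return None


if __name__ == "__main__":
    root = find_git_root(os.getcwd())
    if root is None:
        print("Current directory is not inside a git repository.")
    else:
        print(f"Git repository root: {root}")
```

If the helper returns `None`, pick a different `working_directory` or initialize a repository before launching.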
- Attributes:
- wrapper : ModelWrapper
The main model wrapper.
- run_config : RunConfig
Running configuration for the model.
- experiment_dir : Path
The path object to the experiment directory.
- launch(working_directory: str, debug: bool = False) -> dict[str, float]
Launch the prototype experiment (train and evaluate the single prototype model) and return metrics.
- Parameters:
- working_directory : str
The working directory points to a git repository that is used for keeping track of the code differences.
- debug : bool, optional
Whether to train models in debug mode, by default False.
- Returns:
- metrics : dict[str, float]
Metrics returned after training.
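Since `launch` returns a plain `dict[str, float]`, the result can be inspected with ordinary Python. The sketch below shows one way to print such a mapping. `format_metrics` is a hypothetical helper, and the metric names and values are placeholders, not output from a real run.

```python
def format_metrics(metrics: dict[str, float], precision: int = 4) -> str:
    """Render a metrics mapping as aligned `name : value` lines, sorted by name."""
    if not metrics:
        return "(no metrics)"
    width = max(len(name) for name in metrics)
    return "\n".join(
        f"{name:<{width}} : {value:.{precision}f}"
        for name, value in sorted(metrics.items())
    )


if __name__ == "__main__":
    # Placeholder values standing in for the dict returned by `launch`.
    example_metrics = {"val_loss": 0.4213, "train_loss": 0.3871}
    print(format_metrics(example_metrics))
```

Sorting by name keeps the output stable across runs, which makes it easier to compare logs from different experiments.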