Prototype Trainer

class ablator.main.proto.ProtoTrainer(wrapper: ModelWrapper, run_config: RunConfig)[source]

Bases: object

Manages resources for prototyping. This trainer runs an experiment on a single prototype model, so no hyperparameter optimization (HPO) is performed.

Raises:
RuntimeError

If the experiment directory is not defined in the running configuration.
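The precondition above can be illustrated with a minimal sketch. `StubRunConfig`, `StubProtoTrainer`, `try_build`, and the error message are hypothetical stand-ins written for this illustration, not part of ablator; the point at which the library performs the check, and its exact wording, may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubRunConfig:
    # Hypothetical stand-in for RunConfig; only the field relevant here.
    experiment_dir: Optional[str] = None


class StubProtoTrainer:
    # Hypothetical stand-in that mirrors only the documented precondition.
    def __init__(self, run_config: StubRunConfig):
        if run_config.experiment_dir is None:
            # Assumed message; the library's wording may differ.
            raise RuntimeError("Must specify an experiment directory.")
        self.run_config = run_config


def try_build(run_config: StubRunConfig) -> Optional[StubProtoTrainer]:
    """Return a trainer, or None when the configuration is rejected."""
    try:
        return StubProtoTrainer(run_config)
    except RuntimeError as err:
        print(f"Configuration error: {err}")
        return None


# No experiment directory: the error is caught and None is returned.
missing = try_build(StubRunConfig())
# With an experiment directory: construction succeeds.
valid = try_build(StubRunConfig(experiment_dir="/tmp/experiments"))
```

Setting `experiment_dir` in the run config before constructing the trainer, as the example workflow does, avoids the error entirely.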

Examples

The following is a complete workflow for launching a prototype experiment with ProtoTrainer, from defining the configuration to launching the experiment:

  • Define training config:

>>> my_optimizer_config = OptimizerConfig("sgd", {"lr": 0.5, "weight_decay": 0.5})
>>> my_scheduler_config = SchedulerConfig("step", arguments={"step_size": 1, "gamma": 0.99})
>>> train_config = TrainConfig(
...     dataset="[Dataset Name]",
...     batch_size=32,
...     epochs=10,
...     optimizer_config=my_optimizer_config,
...     scheduler_config=my_scheduler_config,
...     rand_weights_init=True,
... )
  • Define model config. Here we use the default one with no custom hyperparameters. A custom model config is mainly useful when running HPO over your model’s hyperparameters in parallel experiments with `ParallelTrainer`, which requires `ParallelConfig` instead of `RunConfig`:

>>> model_config = ModelConfig()
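  • Define run config (`CustomRunConfig` is not defined in this example; it is assumed to be a user-defined subclass of `RunConfig`):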

>>> run_config = CustomRunConfig(
...     train_config=train_config,
...     model_config=model_config,
...     metrics_n_batches=800,
...     experiment_dir="/tmp/experiments",
...     device="cpu",
...     amp=False,
...     random_seed=42,
... )
  • Create model wrapper:

>>> class MyModelWrapper(ModelWrapper):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...
...     def make_dataloader_train(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(<train_dataset>, batch_size=32, shuffle=True)
...
...     def make_dataloader_val(self, run_config: CustomRunConfig):
...         return torch.utils.data.DataLoader(<val_dataset>, batch_size=32, shuffle=False)
  • With the configurations and the model wrapper in place, initialize and launch the prototype trainer:

>>> wrapper = MyModelWrapper(
...     model_class=<your_ModelModule_class>,
... )
>>> ablator = ProtoTrainer(
...     wrapper=wrapper,
...     run_config=run_config,
... )
>>> metrics = ablator.launch()
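The `<train_dataset>` and `<val_dataset>` placeholders in the wrapper above must be replaced with real datasets. As a minimal sketch, assuming small synthetic tensors are acceptable for prototyping, they could be filled as follows. The dataset sizes, feature dimension, and helper names `make_train_loader` and `make_val_loader` are illustrative assumptions, not part of ablator.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic data (assumed shapes): 10 features per sample, 2 classes.
torch.manual_seed(42)
train_dataset = TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
val_dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))


def make_train_loader(batch_size: int = 32) -> DataLoader:
    # Shuffle training samples each epoch.
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


def make_val_loader(batch_size: int = 32) -> DataLoader:
    # Keep validation order fixed so results are comparable across runs.
    return DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


# Pull one batch to check shapes before wiring the loaders into a wrapper.
features, labels = next(iter(make_train_loader()))
```

The bodies of `make_dataloader_train` and `make_dataloader_val` in the wrapper could then return loaders built this way.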
Attributes:
wrapper : ModelWrapper

The main model wrapper.

run_config : RunConfig

Running configuration for the model.

launch(debug: bool = False)[source]

Launch the prototype experiment (train and evaluate the single prototype model) and return metrics.

Parameters:
debug : bool, default=False

Whether to train the model in debug mode.

Returns:
metrics : Metrics

Metrics returned after training.