ablator.main.state package
Submodules
ablator.main.state.state module
- class ablator.main.state.state.ExperimentState(experiment_dir: Path, config: ParallelConfig, logger: FileLogger | None = None, resume: bool = False, sampler_seed: int | None = None)
Bases: object
Initializes the ExperimentState: sets up the databases that store the training states and the sampler, and creates trials based on the total number of trials specified in the config.
- Parameters:
- experiment_dir : Path
The directory where the experiment data will be stored.
- config : ParallelConfig
The configuration object that defines the experiment settings.
- logger : FileLogger | None
The logger for outputting experiment logs. If not specified, a dummy logger will be used. By default None.
- resume : bool
Whether to resume a previously interrupted experiment. By default False.
- sampler_seed : int | None
The seed to use for the trial sampler. By default None.
- Raises:
- RuntimeError
If the specified search_space parameter is not found in the configuration.
- AssertionError
If config.search_space is empty.
- RuntimeError
If the experiment database already exists and resume is False.
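The constructor's documented failure modes (an empty search space and an existing database without resume) can be mirrored in a few lines. This is a minimal sketch, not ablator's code: `validate_experiment_setup` and the `experiment.db` file name are hypothetical, and the search space is a plain dict standing in for config.search_space.

```python
from pathlib import Path
from typing import Any


def validate_experiment_setup(
    experiment_dir: Path,
    search_space: dict[str, Any],
    resume: bool,
    db_name: str = "experiment.db",  # hypothetical file name, not ablator's
) -> Path:
    # Mirrors the documented AssertionError: the search space must not be empty.
    assert len(search_space) > 0, "config.search_space is empty"
    db_path = experiment_dir / db_name
    # Mirrors the documented RuntimeError: refuse to reuse an existing
    # experiment database unless resume=True.
    if db_path.exists() and not resume:
        raise RuntimeError(f"{db_path} already exists and resume=False")
    return db_path
```

In this sketch, a fresh directory passes, a second run against the same directory only passes with resume=True, and an empty search space fails before the file system is touched.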
- get_trial_configs_by_state(state: TrialState) → list[ablator.config.mp.ParallelConfig]
Get the configurations of all trials in the given state.
- Parameters:
- state : TrialState
The state of a trial.
- Returns:
- list[ParallelConfig]
List of configurations of all the trials in that state.
- get_trials_by_state(state: TrialState) → list[ablator.main.state.store.Trial]
Get all trials in the given state.
- Parameters:
- state : TrialState
Represents the state of a trial.
- Returns:
- list[Trial]
List of all the trials in that given state.
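get_trials_by_state and get_trial_configs_by_state are both filters on the stored trial state. The in-memory sketch below shows that relationship under stated assumptions: the `Trial` dataclass is a hypothetical stand-in for the ORM row, and configs are plain dicts where the real method returns ParallelConfig objects read from the experiment database.

```python
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any


class TrialState(IntEnum):
    # Subset of the documented TrialState values.
    RUNNING = 0
    COMPLETE = 1
    PRUNED = 2
    FAIL = 3
    WAITING = 4


@dataclass
class Trial:
    # Hypothetical stand-in for the ORM Trial row; only the fields used here.
    id: int
    state: TrialState
    config_param: dict[str, Any] = field(default_factory=dict)


def get_trials_by_state(trials: list[Trial], state: TrialState) -> list[Trial]:
    # Keep the trials whose stored state matches exactly.
    return [t for t in trials if t.state == state]


def get_trial_configs_by_state(trials: list[Trial], state: TrialState) -> list[dict[str, Any]]:
    # Same filter, projected onto each trial's configuration.
    return [t.config_param for t in get_trials_by_state(trials, state)]
```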
- sample_trial() → tuple[int, ablator.config.mp.ParallelConfig]
Samples a trial from the search space and persists the trial state to the experiment database.
- Returns:
- tuple[int, ParallelConfig]
The unique trial_id with respect to the sampler, and the trial configuration.
- Raises:
- StopIteration
If the number of invalid trials sampled exceeds the internal upper bound (20) or the sampler raises a StopIteration exception indicating that the search space has been exhaustively evaluated.
- TypeError
If the trial parameters are invalid and config.ignore_invalid_params is set to False.
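The two StopIteration conditions of sample_trial can be illustrated with a bounded retry loop. This is a sketch only: `sample` and `is_valid` are hypothetical callables standing in for the HPO sampler and ablator's validity check, and the real method additionally persists each sampled trial to the experiment database.

```python
from typing import Any, Callable

MAX_INVALID = 20  # the documented internal upper bound on invalid samples


def sample_valid_trial(
    sample: Callable[[], dict[str, Any]],
    is_valid: Callable[[dict[str, Any]], bool],
) -> dict[str, Any]:
    invalid = 0
    while True:
        # `sample` may itself raise StopIteration when the search space is
        # exhausted; it propagates unchanged, matching the documented behavior.
        params = sample()
        if is_valid(params):
            return params
        invalid += 1
        if invalid > MAX_INVALID:
            raise StopIteration(f"More than {MAX_INVALID} invalid trials sampled")
```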
- static search_space_dot_path(trial: ParallelConfig) → dict[str, Any]
Returns a dictionary of parameter names and their corresponding values for a given trial.
- Parameters:
- trial : ParallelConfig
The trial object to get the search space dot paths from.
- Returns:
- dict[str, Any]
A dictionary of parameter names and their corresponding values.
Examples
>>> search_space = {"train_config.optimizer_config.arguments.lr": SearchSpace(value_range=[0, 0.1], value_type="float")}
>>> ExperimentState.search_space_dot_path(trial)  # a trial sampled from search_space
{'train_config.optimizer_config.arguments.lr': 0.1}
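A dot path such as "train_config.optimizer_config.arguments.lr" names one nested field of the trial configuration. The sketch below resolves dot paths against a nested dict; it assumes the config is a plain nested dict, whereas the real method works on a ParallelConfig object, and both helper names are hypothetical.

```python
from typing import Any


def get_by_dot_path(config: dict[str, Any], path: str) -> Any:
    # Walk the nested dict one key at a time: "a.b.c" -> config["a"]["b"]["c"].
    node: Any = config
    for key in path.split("."):
        node = node[key]
    return node


def search_space_dot_path(config: dict[str, Any], search_space: dict[str, Any]) -> dict[str, Any]:
    # Map each search-space dot path to the value it resolves to in this trial's config.
    return {path: get_by_dot_path(config, path) for path in search_space}
```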
- static tune_trial_str(trial: ParallelConfig) → str
Generate a string representation of a trial object.
- Parameters:
- trial : ParallelConfig
The trial object to generate a string representation for.
- Returns:
- str
A string representation of the trial object.
- update_trial_state(trial_id: int, metrics: dict[str, float] | None = None, state: TrialState = TrialState.RUNNING) → None
Update the state of a trial in the experiment database and report it to Optuna.
- Parameters:
- trial_id : int
The id of the trial to update.
- metrics : dict[str, float] | None
The metrics of the trial. By default None.
- state : TrialState
The state of the trial. By default TrialState.RUNNING.
- Raises:
- RuntimeError
If the experiment state is corrupted, i.e. duplicate trials are found.
Examples
>>> experiment.update_trial_state(1, {"loss": 0.1}, TrialState.COMPLETE)
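The corruption check in update_trial_state amounts to requiring exactly one stored row per trial id. The sketch below models that on an in-memory list; the `Trial` dataclass is a hypothetical stand-in for the database row, treating a missing id as an error is an assumption of this sketch, and the call that reports the new state to Optuna is omitted.

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional


class TrialState(IntEnum):
    # Subset of the documented TrialState values.
    RUNNING = 0
    COMPLETE = 1


@dataclass
class Trial:
    # Hypothetical in-memory stand-in for a row in the experiment database.
    id: int
    state: TrialState = TrialState.RUNNING
    metrics: Optional[dict[str, float]] = None


def update_trial_state(
    trials: list[Trial],
    trial_id: int,
    metrics: Optional[dict[str, float]] = None,
    state: TrialState = TrialState.RUNNING,
) -> None:
    matches = [t for t in trials if t.id == trial_id]
    # More than one row with the same id means the stored state is corrupted.
    # Raising for a missing id as well is an assumption of this sketch.
    if len(matches) != 1:
        raise RuntimeError(f"Corrupted experiment state: {len(matches)} trials with id {trial_id}")
    trial = matches[0]
    trial.state = state
    trial.metrics = metrics
    # The real method also reports the update to Optuna; omitted here.
```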
- valid_trials() → list[ablator.main.state.store.Trial]
- Returns:
- list[Trial]
All valid trials, i.e. trials that were not pruned as duplicates or as invalid.
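Following the docstring's wording, only the PRUNED_INVALID and PRUNED_DUPLICATE states disqualify a trial, so a trial pruned for poor performance still counts as valid. The sketch below encodes that reading; the `states` mapping from trial id to state is a hypothetical in-memory view, not ablator's storage.

```python
from enum import IntEnum


class TrialState(IntEnum):
    # Subset of the documented TrialState values.
    COMPLETE = 1
    PRUNED_INVALID = 5
    PRUNED_DUPLICATE = 6
    PRUNED_POOR_PERFORMANCE = 7


# Per the docstring, only these two pruned states make a trial "not valid".
EXCLUDED_STATES = {TrialState.PRUNED_INVALID, TrialState.PRUNED_DUPLICATE}


def valid_trials(states: dict[int, TrialState]) -> list[int]:
    # `states` maps a trial id to its current state.
    return [trial_id for trial_id, state in states.items() if state not in EXCLUDED_STATES]
```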
ablator.main.state.store module
- class ablator.main.state.store.Base(**kwargs: Any)
Bases: DeclarativeBase
- metadata: ClassVar[MetaData] = MetaData()
Refers to the MetaData collection that will be used for new Table objects.
See also: orm_declarative_metadata
- registry: ClassVar[_RegistryType] = <sqlalchemy.orm.decl_api.registry object>
Refers to the registry in use where new Mapper objects will be associated.
- class ablator.main.state.store.Trial(**kwargs)
Bases: Base
Class to store data about a trial.
- Attributes:
- id: Mapped[int]
The trial ID, used for internal purposes.
- config_uid: Mapped[str]
The configuration identifier associated with the trial’s unique attributes.
- metrics: Mapped[PickleType]
The performance metrics reported by the trial, as a dict[str, float] mapping each metric name to its value.
- config_param: Mapped[PickleType]
The configuration parameters for the specific trial, including the defaults.
- aug_config_param: Mapped[PickleType]
The augmenting configuration picked by the config sampler: only the values that differ from the default config (excluding derived properties).
- trial_num: Mapped[Integer]
The trial number in the internal HPO sampler, used to communicate with the sampler.
- state: Mapped[PickleType]
The TrialState of the trial.
- runtime_errors: Mapped[int]
The total number of runtime errors the trial has encountered; incremented every time the trial hits a recoverable error.
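The relationship between config_param and aug_config_param can be pictured as a diff of the trial's configuration against the defaults. The sketch below is illustrative only: it assumes configs are nested dicts, `flatten` and `augmenting_params` are hypothetical helpers, and it does not model the exclusion of derived properties.

```python
from typing import Any


def flatten(cfg: dict[str, Any], prefix: str = "") -> dict[str, Any]:
    # Turn {"a": {"b": 1}} into {"a.b": 1} so nested values can be compared key by key.
    flat: dict[str, Any] = {}
    for key, value in cfg.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def augmenting_params(default_cfg: dict[str, Any], trial_cfg: dict[str, Any]) -> dict[str, Any]:
    # Keep only the entries whose value differs from (or is absent in) the defaults.
    defaults = flatten(default_cfg)
    return {k: v for k, v in flatten(trial_cfg).items() if k not in defaults or defaults[k] != v}
```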
- aug_config_param: Mapped[PickleType]
- config_param: Mapped[PickleType]
- config_uid: Mapped[str]
- id: Mapped[int]
- metrics: Mapped[PickleType]
- runtime_errors: Mapped[int]
- state: Mapped[PickleType]
- trial_num: Mapped[Integer]
- class ablator.main.state.store.TrialState(value)
Bases: IntEnum
An enumeration of the possible states of a trial, extending the basic trial states with more fine-grained pruned states.
- Attributes:
- RUNNING : int
A trial that has been successfully scheduled to run.
- COMPLETE : int
A successfully completed trial.
- PRUNED : int
A trial pruned for any of various reasons.
- FAIL : int
A trial that produced an unrecoverable error during execution.
- WAITING : int
A trial that is waiting to be scheduled to run.
- PRUNED_INVALID : int
A trial that was pruned during sampling because it was invalid.
- PRUNED_DUPLICATE : int
A trial that was sampled but was already present.
- PRUNED_POOR_PERFORMANCE : int
A trial that was pruned during execution for poor performance.
- FAIL_RECOVERABLE : int
A trial that encountered a recoverable error during execution.
- COMPLETE: int = 1
- FAIL: int = 3
- FAIL_RECOVERABLE: int = 8
- PRUNED: int = 2
- PRUNED_DUPLICATE: int = 6
- PRUNED_INVALID: int = 5
- PRUNED_POOR_PERFORMANCE: int = 7
- RUNNING: int = 0
- WAITING: int = 4
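Because TrialState is an IntEnum, its members compare equal to plain integers and can be rebuilt from a stored integer value. The first five values appear to line up with Optuna's own TrialState, and the remaining ones add the finer-grained pruned and recoverable-failure states. The sketch below re-declares the enum with the documented values; `PRUNED_STATES` is a hypothetical helper, not part of ablator.

```python
from enum import IntEnum


class TrialState(IntEnum):
    # Values copied from the documented TrialState members.
    RUNNING = 0
    COMPLETE = 1
    PRUNED = 2
    FAIL = 3
    WAITING = 4
    PRUNED_INVALID = 5
    PRUNED_DUPLICATE = 6
    PRUNED_POOR_PERFORMANCE = 7
    FAIL_RECOVERABLE = 8


# IntEnum members compare equal to plain ints, and an int converts back to its member.
assert TrialState.COMPLETE == 1
assert TrialState(7) is TrialState.PRUNED_POOR_PERFORMANCE

# Hypothetical helper grouping every pruned variant by name prefix.
PRUNED_STATES = {s for s in TrialState if s.name.startswith("PRUNED")}
```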