ablator.main.state package
Submodules
ablator.main.state.state module
- class ablator.main.state.state.ExperimentState(experiment_dir: Path, config: ParallelConfig, logger: FileLogger | None = None, resume: bool = False, sampler_seed: int | None = None)
Bases: object
- get_trial_configs_by_state(state: TrialState) → list[ablator.config.mp.ParallelConfig]
- get_trials_by_state(state: TrialState) → list[ablator.main.state.store.Trial]
- sample_trial() → tuple[int, ablator.config.mp.ParallelConfig]
Samples a trial from the search space and persists the trial state to the experiment database.
- Returns:
- tuple[int, ParallelConfig]
The unique trial_id with respect to the sampler, and the trial configuration.
- Raises:
- StopIteration
If the number of invalid trials sampled exceeds the internal upper bound (20) or the sampler raises a StopIteration exception indicating that the search space has been exhaustively evaluated.
- TypeError
If the trial parameters are invalid and config.ignore_invalid_params is set to False.
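The retry behaviour described under Raises can be sketched as a small loop. This is a minimal, hypothetical illustration of the documented contract, not ablator's implementation: sample_valid, sample and is_valid are illustrative names, candidates are assumed to be plain dicts, and the bound of 20 is taken from the description above.

```python
from typing import Any, Callable, Dict

# Documented internal upper bound on invalid samples.
MAX_INVALID_SAMPLES = 20


def sample_valid(
    sample: Callable[[], Dict[str, Any]],
    is_valid: Callable[[Dict[str, Any]], bool],
    max_invalid: int = MAX_INVALID_SAMPLES,
) -> Dict[str, Any]:
    """Hypothetical sketch: draw candidates until one is valid.

    Raises StopIteration once more than `max_invalid` invalid candidates
    have been drawn. If `sample` itself raises StopIteration (search space
    exhausted), that exception propagates unchanged.
    """
    invalid = 0
    while True:
        candidate = sample()
        if is_valid(candidate):
            return candidate
        invalid += 1
        if invalid > max_invalid:
            raise StopIteration(f"more than {max_invalid} invalid trials sampled")
```

A sampler that signals exhaustion by raising StopIteration needs no special handling in this sketch: the exception simply propagates, so a caller can treat both stop conditions alike.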
- static search_space_dot_path(trial: ParallelConfig) → dict[str, Any]
Returns a dictionary of parameter names and their corresponding values for a given trial.
- Parameters:
- trial : ParallelConfig
The trial object to get the search space dot paths from.
- Returns:
- dict[str, Any]
A dictionary of parameter names and their corresponding values.
Examples
>>> search_space = {"train_config.optimizer_config.arguments.lr": SearchSpace(value_range=[0, 0.1], value_type="float")}
>>> ExperimentState.search_space_dot_path(trial)  # trial sampled from search_space with lr = 0.1
{"train_config.optimizer_config.arguments.lr": 0.1}
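Resolving the search-space dot paths against a trial configuration can be sketched as follows. The sketch assumes the configuration is available as a plain nested dict (no ParallelConfig object is involved) and that the search-space keys are dot paths like the one in the example; get_by_dot_path and search_space_values are hypothetical helper names.

```python
from typing import Any, Dict, Iterable


def get_by_dot_path(config: Dict[str, Any], path: str) -> Any:
    """Follow a dot path such as 'a.b.c' through nested dictionaries."""
    node: Any = config
    for key in path.split("."):
        node = node[key]  # KeyError if the path does not exist
    return node


def search_space_values(config: Dict[str, Any], dot_paths: Iterable[str]) -> Dict[str, Any]:
    """Map each search-space dot path to the value it takes in `config`."""
    return {path: get_by_dot_path(config, path) for path in dot_paths}
```

Passing the search-space dict itself as dot_paths works because iterating a dict yields its keys.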
- static tune_trial_str(trial: ParallelConfig) → str
Generate a string representation of a trial object.
- Parameters:
- trial : ParallelConfig
The trial object to generate a string representation for.
- Returns:
- str
A string representation of the trial object.
- update_trial_state(trial_id: int, metrics: dict[str, float] | None = None, state: TrialState = TrialState.RUNNING) → None
Updates the state of a trial in the experiment database and reports it to Optuna.
- Parameters:
- trial_id : int
The id of the trial to update.
- metrics : dict[str, float] | None, optional
The metrics of the trial, by default None.
- state : TrialState, optional
The state of the trial, by default TrialState.RUNNING.
Examples
>>> trial_id, trial_config = experiment.sample_trial()
>>> experiment.update_trial_state(trial_id, {"loss": 0.1}, TrialState.COMPLETE)
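The update described above, recording the new state and optional metrics and then reporting them to the optimizer, can be sketched with an in-memory store. Everything here is a hypothetical stand-in: the dict of TrialRecord objects plays the role of the experiment database and the optional tell callback plays the role of reporting to Optuna. Only the parameter names, defaults and TrialState values come from this page.

```python
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Callable, Dict, Optional


class TrialState(IntEnum):
    # Values as documented for ablator.main.state.store.TrialState.
    RUNNING = 0
    COMPLETE = 1
    PRUNED = 2
    FAIL = 3
    WAITING = 4
    PRUNED_INVALID = 5
    PRUNED_DUPLICATE = 6
    PRUNED_POOR_PERFORMANCE = 7
    FAIL_RECOVERABLE = 8


@dataclass
class TrialRecord:
    # Hypothetical stand-in for one row of the experiment database.
    state: TrialState = TrialState.WAITING
    metrics: Dict[str, float] = field(default_factory=dict)


TellFn = Callable[[int, TrialState, Optional[Dict[str, float]]], None]


def update_trial_state(
    store: Dict[int, TrialRecord],
    trial_id: int,
    metrics: Optional[Dict[str, float]] = None,
    state: TrialState = TrialState.RUNNING,
    tell: Optional[TellFn] = None,
) -> None:
    """Hypothetical sketch: record the new state, then notify the optimizer."""
    record = store[trial_id]  # KeyError for an unknown trial id
    record.state = state
    if metrics is not None:
        record.metrics = dict(metrics)
    if tell is not None:
        # Stand-in for reporting the result to Optuna.
        tell(trial_id, state, metrics)
```

In this sketch, omitting metrics leaves any previously stored metrics untouched, which matches the None default documented above.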
- valid_trials() → list[ablator.main.state.store.Trial]
ablator.main.state.store module
- class ablator.main.state.store.Base(**kwargs: Any)
Bases: DeclarativeBase
- metadata: ClassVar[MetaData] = MetaData()
Refers to the MetaData collection that will be used for new Table objects.
See also: orm_declarative_metadata
- registry: ClassVar[_RegistryType] = <sqlalchemy.orm.decl_api.registry object>
Refers to the registry in use where new Mapper objects will be associated.
- class ablator.main.state.store.Trial(**kwargs)
Bases: Base
- aug_config_param: Mapped[PickleType]
- config_param: Mapped[PickleType]
- config_uid: Mapped[str]
- id: Mapped[int]
- metrics: Mapped[PickleType]
- runtime_errors: Mapped[int]
- state: Mapped[PickleType]
- trial_num: Mapped[Integer]
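The Trial columns can be approximated with the standard-library sqlite3 module to show how pickled columns behave. This is a hypothetical schema for illustration only; ablator defines the table through SQLAlchemy, and add_trial and trials_by_state are illustrative names (the latter loosely mirrors get_trials_by_state but returns row ids rather than Trial objects).

```python
import pickle
import sqlite3
from typing import Any, List

# Hypothetical table loosely mirroring the Trial columns; the PickleType
# columns become BLOBs holding pickle.dumps(...) output.
SCHEMA = """
CREATE TABLE trial (
    id INTEGER PRIMARY KEY,
    config_uid TEXT NOT NULL,
    config_param BLOB,
    aug_config_param BLOB,
    trial_num INTEGER,
    metrics BLOB,
    state BLOB,
    runtime_errors INTEGER NOT NULL DEFAULT 0
)
"""


def add_trial(conn: sqlite3.Connection, config_uid: str, config: Any, trial_num: int, state: Any) -> int:
    """Insert a trial row and return its id."""
    cur = conn.execute(
        "INSERT INTO trial (config_uid, config_param, trial_num, metrics, state) "
        "VALUES (?, ?, ?, ?, ?)",
        (config_uid, pickle.dumps(config), trial_num, pickle.dumps({}), pickle.dumps(state)),
    )
    return cur.lastrowid


def trials_by_state(conn: sqlite3.Connection, state: Any) -> List[int]:
    """Return ids of trials whose unpickled state equals `state`."""
    rows = conn.execute("SELECT id, state FROM trial ORDER BY id").fetchall()
    return [row_id for row_id, blob in rows if pickle.loads(blob) == state]
```

Because the state column holds pickled data, it is opaque to SQL in this sketch; trials_by_state therefore unpickles every row and compares in Python, which is acceptable for small experiment databases.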
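- class ablator.main.state.store.TrialState(value)
Bases: IntEnum
An enumeration of possible trial states. It extends the basic states (RUNNING, COMPLETE, PRUNED, FAIL and WAITING, which carry the same values as in Optuna's TrialState) with more specific pruned and failure states.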
- Attributes:
- RUNNING : int
A trial that has been successfully scheduled to run
- COMPLETE : int
A successfully completed trial
- PRUNED : int
A trial pruned for any of various reasons
- FAIL : int
A trial that produced an unrecoverable error during execution
- WAITING : int
A trial that is waiting to be scheduled to run
- PRUNED_INVALID : int
A trial that was pruned during sampling because it was invalid
- PRUNED_DUPLICATE : int
A trial that was sampled but was already present
- PRUNED_POOR_PERFORMANCE : int
A trial that was pruned during execution for poor performance
- FAIL_RECOVERABLE : int
A trial that failed during execution with a recoverable error
- COMPLETE = 1
- FAIL = 3
- FAIL_RECOVERABLE = 8
- PRUNED = 2
- PRUNED_DUPLICATE = 6
- PRUNED_INVALID = 5
- PRUNED_POOR_PERFORMANCE = 7
- RUNNING = 0
- WAITING = 4
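Because TrialState is an IntEnum, stored integers convert back to members and members compare equal to plain ints. The sketch re-declares the documented values and adds a hypothetical is_finished helper; the grouping of states into finished and unfinished is an assumption made for illustration, not ablator's own classification.

```python
from enum import IntEnum


class TrialState(IntEnum):
    # Values as documented above. RUNNING through WAITING (0-4) carry the same
    # numbers as optuna.trial.TrialState; 5-8 are the additional states.
    RUNNING = 0
    COMPLETE = 1
    PRUNED = 2
    FAIL = 3
    WAITING = 4
    PRUNED_INVALID = 5
    PRUNED_DUPLICATE = 6
    PRUNED_POOR_PERFORMANCE = 7
    FAIL_RECOVERABLE = 8


# Hypothetical grouping for illustration only.
PRUNED_STATES = frozenset(
    {
        TrialState.PRUNED,
        TrialState.PRUNED_INVALID,
        TrialState.PRUNED_DUPLICATE,
        TrialState.PRUNED_POOR_PERFORMANCE,
    }
)
FINISHED_STATES = PRUNED_STATES | {TrialState.COMPLETE, TrialState.FAIL}


def is_finished(state: int) -> bool:
    """Under this hypothetical grouping, True if the trial will not run again."""
    return TrialState(state) in FINISHED_STATES
```

Accepting plain ints in is_finished mirrors the case where a state is read back from storage as a number; TrialState(value) raises ValueError for an integer that is not a documented state.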