ablator.main.state package

Submodules

ablator.main.state.state module

class ablator.main.state.state.ExperimentState(experiment_dir: Path, config: ParallelConfig, logger: FileLogger | None = None, resume: bool = False, sampler_seed: int | None = None)

Bases: object

get_trial_configs_by_state(state: TrialState) → list[ablator.config.mp.ParallelConfig]
get_trials_by_state(state: TrialState) → list[ablator.main.state.store.Trial]
sample_trial() → tuple[int, ablator.config.mp.ParallelConfig]

Samples a trial from the search space and persists the trial state to the experiment database.

Returns:
tuple[int, ParallelConfig]

The unique trial_id with respect to the sampler, and the trial configuration.

Raises:
StopIteration

If the number of invalid trials sampled exceeds the internal upper bound (20) or the sampler raises a StopIteration exception indicating that the search space has been exhaustively evaluated.

TypeError

If the trial parameters are invalid and config.ignore_invalid_params is set to False.
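The retry and error behaviour described above can be sketched as follows. This is a minimal, self-contained illustration, not ablator's implementation: `sample_valid_trial`, the `is_valid` callback and the toy sampler are hypothetical, and the sketch assumes that invalid samples are skipped (and counted against the bound of 20) only when `ignore_invalid_params` is true.

```python
from collections.abc import Callable, Iterator
from typing import Any

MAX_INVALID = 20  # the internal upper bound quoted in the docstring


def sample_valid_trial(
    sampler: Iterator[dict[str, Any]],
    is_valid: Callable[[dict[str, Any]], bool],
    ignore_invalid_params: bool,
    max_invalid: int = MAX_INVALID,
) -> dict[str, Any]:
    """Hypothetical helper: draw samples until a valid one is found."""
    invalid = 0
    while True:
        # next() propagates StopIteration once the sampler is exhausted,
        # i.e. the search space has been fully evaluated.
        params = next(sampler)
        if is_valid(params):
            return params
        if not ignore_invalid_params:
            raise TypeError(f"Invalid trial parameters: {params}")
        invalid += 1
        if invalid > max_invalid:
            raise StopIteration(f"More than {max_invalid} invalid trials sampled.")


# Usage: the first sample is invalid (negative lr) and is skipped.
samples = iter([{"lr": -1.0}, {"lr": 0.1}])
print(sample_valid_trial(samples, lambda p: p["lr"] >= 0, ignore_invalid_params=True))
# prints {'lr': 0.1}
```

Raising StopIteration from a plain function (rather than a generator) is legal, so callers can treat "too many invalid samples" and "search space exhausted" the same way.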

static search_space_dot_path(trial: ParallelConfig) → dict[str, Any]

Returns a dictionary of parameter names and their corresponding values for a given trial.

Parameters:
trial : ParallelConfig

The trial object to get the search space dot paths from.

Returns:
dict[str, Any]

A dictionary of parameter names and their corresponding values.

Examples

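>>> search_space = {"train_config.optimizer_config.arguments.lr":
...     SearchSpace(value_range=[0, 0.1], value_type="float")}
>>> # `trial` is a ParallelConfig sampled from search_space with lr = 0.1
>>> ExperimentState.search_space_dot_path(trial)
{'train_config.optimizer_config.arguments.lr': 0.1}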
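The idea behind the dot-path lookup can be illustrated with plain nested dictionaries. In this sketch the nested `dict` standing in for a trial configuration and the helper names `get_dot_path` and `dot_path_values` are assumptions for illustration only; ablator itself works on `ParallelConfig` objects.

```python
from functools import reduce
from typing import Any


def get_dot_path(config: dict[str, Any], path: str) -> Any:
    """Follow a dot-separated path such as "a.b.c" through nested dicts."""
    return reduce(lambda node, key: node[key], path.split("."), config)


def dot_path_values(config: dict[str, Any], search_paths: list[str]) -> dict[str, Any]:
    """Map every search-space dot path to the value it takes in this trial."""
    return {path: get_dot_path(config, path) for path in search_paths}


trial_config = {
    "train_config": {
        "batch_size": 32,
        "optimizer_config": {"name": "adam", "arguments": {"lr": 0.1}},
    }
}
search_paths = ["train_config.optimizer_config.arguments.lr"]
print(dot_path_values(trial_config, search_paths))
# prints {'train_config.optimizer_config.arguments.lr': 0.1}
```

Only the searched parameters appear in the result; fixed values such as `batch_size` are left out.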
static tune_trial_str(trial: ParallelConfig) → str

Generate a string representation of a trial object.

Parameters:
trial : ParallelConfig

The trial object to generate a string representation for.

Returns:
str

A string representation of the trial object.

update_trial_state(trial_id: int, metrics: dict[str, float] | None = None, state: TrialState = TrialState.RUNNING) → None

Updates the state of a trial in the experiment database and reports it to Optuna.

Parameters:
trial_id : int

The id of the trial to update.

metrics : dict[str, float] | None, optional

The metrics of the trial, by default None.

state : TrialState, optional

The state of the trial, by default TrialState.RUNNING.

Examples

>>> trial_id, config = experiment.sample_trial()
>>> experiment.update_trial_state(trial_id, {"loss": 0.1}, TrialState.COMPLETE)
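The database half of this update can be sketched with the standard-library `sqlite3` and `pickle` modules. This is an illustration under assumptions, not ablator's code: the `trials` table, its three columns and the `persist_trial_state` helper are hypothetical, the metrics are pickled the way a `PickleType` column (as on the `Trial` model) would store them, and the Optuna side of the update is omitted.

```python
from __future__ import annotations

import pickle
import sqlite3
from enum import IntEnum


class TrialState(IntEnum):
    # Subset of the documented members, redeclared so the sketch runs standalone.
    RUNNING = 0
    COMPLETE = 1


def persist_trial_state(
    conn: sqlite3.Connection,
    trial_id: int,
    metrics: dict[str, float] | None = None,
    state: TrialState = TrialState.RUNNING,
) -> None:
    """Hypothetical helper: overwrite the state and pickled metrics of one trial."""
    cur = conn.execute(
        "UPDATE trials SET state = ?, metrics = ? WHERE id = ?",
        (int(state), pickle.dumps(metrics), trial_id),
    )
    if cur.rowcount != 1:
        raise KeyError(f"No trial with id {trial_id}")
    conn.commit()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE trials (id INTEGER PRIMARY KEY, state INTEGER, metrics BLOB)")
conn.execute("INSERT INTO trials VALUES (?, ?, ?)", (0, int(TrialState.RUNNING), pickle.dumps(None)))
persist_trial_state(conn, 0, {"loss": 0.1}, TrialState.COMPLETE)

state, blob = conn.execute("SELECT state, metrics FROM trials WHERE id = 0").fetchone()
print(TrialState(state).name, pickle.loads(blob))
# prints COMPLETE {'loss': 0.1}
```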
valid_trials() → list[ablator.main.state.store.Trial]
valid_trials_id() → list[int]

ablator.main.state.store module

class ablator.main.state.store.Base(**kwargs: Any)

Bases: DeclarativeBase

metadata: ClassVar[MetaData] = MetaData()

Refers to the MetaData collection that will be used for new Table objects.

See also

orm_declarative_metadata

registry: ClassVar[_RegistryType] = <sqlalchemy.orm.decl_api.registry object>

Refers to the registry in use, with which new Mapper objects will be associated.

class ablator.main.state.store.Trial(**kwargs)

Bases: Base

aug_config_param: Mapped[PickleType]
config_param: Mapped[PickleType]
config_uid: Mapped[str]
id: Mapped[int]
metrics: Mapped[PickleType]
runtime_errors: Mapped[int]
state: Mapped[PickleType]
trial_num: Mapped[Integer]

class ablator.main.state.store.TrialState(value)

Bases: IntEnum

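An enumeration of possible trial states. The first five members (RUNNING, COMPLETE, PRUNED, FAIL, WAITING) share their names and values with Optuna's optuna.trial.TrialState; the remaining members add finer-grained pruned and failure states.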

Attributes:
RUNNING : int

A trial that has been successfully scheduled to run.

COMPLETE : int

Successfully completed trial.

PRUNED : int

Trial pruned for various reasons.

FAIL : int

Trial that produced an unrecoverable error during execution

WAITING : int

Trial that is waiting to be scheduled to run

PRUNED_INVALID : int

Trial that was pruned during sampling because it was invalid.

PRUNED_DUPLICATE : int

Trial that was sampled but was already present

PRUNED_POOR_PERFORMANCE : int

Trial that was pruned during execution for poor performance

FAIL_RECOVERABLE : int

Trial that produced a recoverable error during execution.

COMPLETE = 1
FAIL = 3
FAIL_RECOVERABLE = 8
PRUNED = 2
PRUNED_DUPLICATE = 6
PRUNED_INVALID = 5
PRUNED_POOR_PERFORMANCE = 7
RUNNING = 0
WAITING = 4
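Because TrialState is an IntEnum, its members compare equal to plain integers, which is convenient when states are read back from a database. The sketch below redeclares the documented members locally (so it runs without ablator) and shows a hypothetical `trials_by_state` filter in the spirit of `ExperimentState.get_trials_by_state`; the `(trial_id, state)` pairs are made-up toy data.

```python
from enum import IntEnum


class TrialState(IntEnum):
    # Local copy of the documented members and values.
    RUNNING = 0
    COMPLETE = 1
    PRUNED = 2
    FAIL = 3
    WAITING = 4
    PRUNED_INVALID = 5
    PRUNED_DUPLICATE = 6
    PRUNED_POOR_PERFORMANCE = 7
    FAIL_RECOVERABLE = 8


def trials_by_state(trials: list[tuple[int, TrialState]], state: TrialState) -> list[int]:
    """Hypothetical filter: ids of the trials currently in `state`."""
    return [trial_id for trial_id, trial_state in trials if trial_state == state]


# Toy (trial_id, state) pairs; TrialState(7) rebuilds a member from a stored int.
trials = [(0, TrialState.COMPLETE), (1, TrialState(7)), (2, TrialState.COMPLETE)]
print(trials_by_state(trials, TrialState.COMPLETE))  # prints [0, 2]
print(TrialState(7).name)  # prints PRUNED_POOR_PERFORMANCE
print(TrialState.PRUNED_INVALID == 5)  # prints True
```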

Module contents