ablator.main.state package

Submodules

ablator.main.state.state module

class ablator.main.state.state.ExperimentState(experiment_dir: Path, config: ParallelConfig, logger: FileLogger | None = None, resume: bool = False, sampler_seed: int | None = None)

Bases: object

Initializes the ExperimentState: sets up the databases that store training states and the trial sampler, then creates trials based on the total number of trials specified in the configuration.

Parameters:
experiment_dir : Path

The directory where the experiment data will be stored.

config : ParallelConfig

The configuration object that defines the experiment settings.

logger : FileLogger | None

The logger for outputting experiment logs. If not specified, a dummy logger is used. By default None.

resume : bool

Whether to resume a previously interrupted experiment, by default False.

sampler_seed : int | None

The seed to use for the trial sampler, by default None.

Raises:
RuntimeError

If the specified search_space parameter is not found in the configuration.

AssertionError

If config.search_space is empty.

RuntimeError

If the experiment database already exists and resume is False.
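
The checks listed under Raises can be sketched with the standard library. This is a minimal illustration under stated assumptions, not ablator's implementation: the helper validate_experiment_setup and the database file name experiment.db are hypothetical, and the search space is modeled as a plain dict.

```python
from pathlib import Path
from typing import Any


def validate_experiment_setup(
    experiment_dir: Path, search_space: dict[str, Any], resume: bool
) -> Path:
    """Hypothetical sketch of the constructor's sanity checks."""
    # An empty search space leaves nothing to sample (the documented AssertionError).
    assert len(search_space) > 0, "config.search_space must not be empty"
    # Hypothetical file name; the real database location may differ.
    db_path = Path(experiment_dir) / "experiment.db"
    # Refuse to reuse an existing database unless resuming (the documented RuntimeError).
    if db_path.exists() and not resume:
        raise RuntimeError(f"{db_path} already exists; pass resume=True to continue.")
    return db_path
```

With resume=True the existing database is accepted, which is what allows an interrupted experiment to continue.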

get_trial_configs_by_state(state: TrialState) → list[ablator.config.mp.ParallelConfig]

Get the configurations of all trials in the given state.

Parameters:
state : TrialState

The state of a trial.

Returns:
list[ParallelConfig]

List of configurations of all the trials in that state.

get_trials_by_state(state: TrialState) → list[ablator.main.state.store.Trial]

Get all trials in the given state.

Parameters:
state : TrialState

Represents the state of a trial.

Returns:
list[Trial]

List of all the trials in that given state.
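
Both lookups select trials by their stored state; get_trial_configs_by_state returns each trial's configuration instead of the Trial object. A minimal in-memory sketch of that filtering, assuming a hypothetical TrialRecord dataclass in place of the database-backed Trial and plain dicts in place of ParallelConfig:

```python
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any


class TrialState(IntEnum):
    # Subset of the documented TrialState members and values.
    RUNNING = 0
    COMPLETE = 1
    WAITING = 4


@dataclass
class TrialRecord:
    """Hypothetical in-memory stand-in for a stored Trial row."""
    id: int
    state: TrialState
    config_param: dict[str, Any] = field(default_factory=dict)


def trials_by_state(trials: list[TrialRecord], state: TrialState) -> list[TrialRecord]:
    # Keep only the trials whose stored state matches the requested one.
    return [t for t in trials if t.state == state]


def trial_configs_by_state(trials: list[TrialRecord], state: TrialState) -> list[dict[str, Any]]:
    # Same filter, but return each trial's configuration instead of the record.
    return [t.config_param for t in trials_by_state(trials, state)]


trials = [
    TrialRecord(0, TrialState.COMPLETE, {"lr": 0.01}),
    TrialRecord(1, TrialState.RUNNING, {"lr": 0.05}),
    TrialRecord(2, TrialState.COMPLETE, {"lr": 0.1}),
]
print([t.id for t in trials_by_state(trials, TrialState.COMPLETE)])  # [0, 2]
print(trial_configs_by_state(trials, TrialState.RUNNING))  # [{'lr': 0.05}]
```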

sample_trial() → tuple[int, ablator.config.mp.ParallelConfig]

Samples a trial from the search space and persists the trial state to the experiment database.

Returns:
tuple[int, ParallelConfig]

The unique trial_id with respect to the sampler, and the trial configuration.

Raises:
StopIteration

If the number of invalid trials sampled exceeds the internal upper bound (20), or if the sampler raises StopIteration, indicating that the search space has been exhaustively evaluated.

TypeError

If the trial parameters are invalid and config.ignore_invalid_params is False.
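
The two StopIteration conditions and the TypeError can be read as a bounded retry loop around the sampler. A minimal sketch, assuming hypothetical sample and is_valid callables in place of the real sampler and config validation; only the bound of 20 and the ignore_invalid_params behavior come from the description above:

```python
from typing import Any, Callable

MAX_INVALID_SAMPLES = 20  # the internal upper bound quoted above


def sample_valid_trial(
    sample: Callable[[], dict[str, Any]],
    is_valid: Callable[[dict[str, Any]], bool],
    *,
    ignore_invalid_params: bool,
) -> dict[str, Any]:
    """Hypothetical sketch of the retry loop described for sample_trial."""
    n_invalid = 0
    while True:
        # A StopIteration from `sample` (search space exhausted) propagates unchanged.
        params = sample()
        if is_valid(params):
            return params
        if not ignore_invalid_params:
            raise TypeError(f"Invalid trial parameters: {params}")
        # Skip the invalid sample, but give up once the bound is exceeded.
        n_invalid += 1
        if n_invalid > MAX_INVALID_SAMPLES:
            raise StopIteration(f"More than {MAX_INVALID_SAMPLES} invalid trials sampled")


candidates = iter([{"lr": -1.0}, {"lr": 0.05}])
print(sample_valid_trial(lambda: next(candidates), lambda p: p["lr"] > 0, ignore_invalid_params=True))
# {'lr': 0.05}
```

Raising StopIteration from an ordinary function (not a generator) is allowed, so the exhausted-sampler case and the too-many-invalid case surface as the same exception type.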

static search_space_dot_path(trial: ParallelConfig) → dict[str, Any]

Returns a dictionary of parameter names and their corresponding values for a given trial.

Parameters:
trial : ParallelConfig

The trial object to get the search space dot paths from.

Returns:
dict[str, ty.Any]

A dictionary of parameter names and their corresponding values.

Examples

>>> search_space = {"train_config.optimizer_config.arguments.lr":
...     SearchSpace(value_range=[0, 0.1], value_type="float")}
>>> ExperimentState.search_space_dot_path(trial)  # a trial sampled with lr=0.1
{'train_config.optimizer_config.arguments.lr': 0.1}
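
Each dot path is a sequence of keys into the nested trial configuration. A minimal sketch of resolving such paths, assuming the trial configuration is available as a plain nested dict (the real method takes a ParallelConfig, and the helper names below are hypothetical):

```python
from typing import Any


def get_by_dot_path(config: dict[str, Any], path: str) -> Any:
    # Walk nested dicts one key at a time: "a.b.c" -> config["a"]["b"]["c"].
    value: Any = config
    for key in path.split("."):
        value = value[key]
    return value


def dot_path_values(config: dict[str, Any], search_space_paths: list[str]) -> dict[str, Any]:
    """Hypothetical stand-in for search_space_dot_path on a plain nested dict."""
    return {path: get_by_dot_path(config, path) for path in search_space_paths}


trial_config = {
    "train_config": {
        "batch_size": 32,
        "optimizer_config": {"name": "adam", "arguments": {"lr": 0.1}},
    }
}
print(dot_path_values(trial_config, ["train_config.optimizer_config.arguments.lr"]))
# {'train_config.optimizer_config.arguments.lr': 0.1}
```

Only the paths named in the search space are reported, so fixed settings such as batch_size above do not appear in the result.
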

static tune_trial_str(trial: ParallelConfig) → str

Generate a string representation of a trial object.

Parameters:
trial : ParallelConfig

The trial object to generate a string representation for.

Returns:
str

A string representation of the trial object.

update_trial_state(trial_id: int, metrics: dict[str, float] | None = None, state: TrialState = TrialState.RUNNING) → None

Update the state of a trial in the experiment database and report it to Optuna.

Parameters:
trial_id : int

The id of the trial to update.

metrics : dict[str, float] | None

The metrics of the trial, by default None.

state : TrialState

The state of the trial, by default TrialState.RUNNING.

Raises:
RuntimeError

If the experiment state is corrupted, i.e. repeated trials are found.

Examples

>>> experiment.update_trial_state(trial_id, {"loss": 0.1}, TrialState.COMPLETE)
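
The corruption check amounts to requiring exactly one stored record per trial id. A minimal sketch of the database side only (the report to Optuna is omitted), assuming a hypothetical list of dict rows in place of the experiment database:

```python
from enum import IntEnum
from typing import Any, Optional


class TrialState(IntEnum):
    # Subset of the documented members and values.
    RUNNING = 0
    COMPLETE = 1


def update_trial_record(
    rows: list[dict[str, Any]],
    trial_id: int,
    metrics: Optional[dict[str, float]] = None,
    state: TrialState = TrialState.RUNNING,
) -> None:
    """Hypothetical sketch of the database side of update_trial_state."""
    matches = [row for row in rows if row["id"] == trial_id]
    if len(matches) > 1:
        # Repeated trials mean the stored experiment state is corrupted.
        raise RuntimeError(f"Found {len(matches)} rows for trial {trial_id}")
    if not matches:
        # Sketch-only choice: an unknown id is treated as a lookup error.
        raise KeyError(trial_id)
    row = matches[0]
    row["state"] = state
    if metrics is not None:
        row["metrics"] = metrics


rows = [{"id": 0, "state": TrialState.RUNNING, "metrics": None}]
update_trial_record(rows, 0, {"loss": 0.1}, TrialState.COMPLETE)
print(rows[0]["state"].name, rows[0]["metrics"])  # COMPLETE {'loss': 0.1}
```
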

valid_trials() → list[ablator.main.state.store.Trial]

Returns:
list[Trial]

All valid trials, i.e. trials that were not pruned as duplicates or as invalid.

valid_trials_id() → list[int]

Returns:
list[int]

The trial ids of all valid trials.
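
Validity here depends only on the stored state. A minimal sketch of the id filter, assuming a hypothetical mapping of trial id to TrialState in place of the database query:

```python
from enum import IntEnum


class TrialState(IntEnum):
    # Subset of the documented members and values.
    RUNNING = 0
    COMPLETE = 1
    PRUNED_INVALID = 5
    PRUNED_DUPLICATE = 6


# Trials pruned as invalid or as duplicates do not count as valid.
EXCLUDED_STATES = {TrialState.PRUNED_INVALID, TrialState.PRUNED_DUPLICATE}


def valid_trial_ids(trial_states: dict[int, TrialState]) -> list[int]:
    """Hypothetical sketch: ids of trials whose state is not excluded."""
    return [tid for tid, state in trial_states.items() if state not in EXCLUDED_STATES]


states = {
    0: TrialState.COMPLETE,
    1: TrialState.PRUNED_DUPLICATE,
    2: TrialState.RUNNING,
    3: TrialState.PRUNED_INVALID,
}
print(valid_trial_ids(states))  # [0, 2]
```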

ablator.main.state.store module

class ablator.main.state.store.Base(**kwargs: Any)

Bases: DeclarativeBase

metadata: ClassVar[MetaData] = MetaData()

Refers to the MetaData collection that will be used for new Table objects.

See also

orm_declarative_metadata

registry: ClassVar[_RegistryType] = <sqlalchemy.orm.decl_api.registry object>

Refers to the registry in use, with which new Mapper objects will be associated.

class ablator.main.state.store.Trial(**kwargs)

Bases: Base

Class that stores data about a trial.

Attributes:
id: Mapped[int]

The trial id, used for internal purposes.

config_uid: Mapped[str]

The configuration identifier associated with the trial’s unique attributes.

metrics: Mapped[PickleType]

The performance metrics reported by the trial, as a dict[str, float] mapping each metric name to its value.

config_param: Mapped[PickleType]

The configuration parameters for the specific trial including the defaults.

aug_config_param: Mapped[PickleType]

The augmenting configuration as picked by the config sampler, i.e. only the values that differ from the default configuration (excluding derived properties).

trial_num: Mapped[Integer]

The trial_num corresponding to the internal HPO sampler, used to communicate with the sampler.

state: Mapped[PickleType]

The trial’s TrialState.

runtime_errors: Mapped[int]

The total number of runtime errors the trial has encountered, incremented every time the trial faces a recoverable error.

aug_config_param: Mapped[PickleType]
config_param: Mapped[PickleType]
config_uid: Mapped[str]
id: Mapped[int]
metrics: Mapped[PickleType]
runtime_errors: Mapped[int]
state: Mapped[PickleType]
trial_num: Mapped[Integer]
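
The PickleType columns hold pickled Python objects. A stdlib-only sketch of a few Trial columns using sqlite3 and pickle, swapped in for illustration in place of SQLAlchemy's ORM; the table name, the config_uid value, and storing state as a plain integer are assumptions of this sketch:

```python
import pickle
import sqlite3

# Hypothetical mirror of a few Trial columns; ablator itself maps Trial with
# SQLAlchemy's ORM rather than raw sqlite3.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE trial ("
    " id INTEGER PRIMARY KEY,"
    " config_uid TEXT NOT NULL,"
    " metrics BLOB,"  # pickled dict[str, float], as a PickleType column would hold
    " state INTEGER NOT NULL,"  # TrialState value, e.g. 0 = RUNNING
    " runtime_errors INTEGER NOT NULL DEFAULT 0)"
)
conn.execute(
    "INSERT INTO trial (id, config_uid, metrics, state) VALUES (?, ?, ?, ?)",
    (0, "uid_0", pickle.dumps({}), 0),
)
# A recoverable error increments the counter.
conn.execute("UPDATE trial SET runtime_errors = runtime_errors + 1 WHERE id = ?", (0,))
# Completion stores the reported metrics and the COMPLETE state (1).
conn.execute(
    "UPDATE trial SET metrics = ?, state = ? WHERE id = ?",
    (pickle.dumps({"loss": 0.1}), 1, 0),
)
blob, state, errors = conn.execute(
    "SELECT metrics, state, runtime_errors FROM trial WHERE id = 0"
).fetchone()
print(pickle.loads(blob), state, errors)  # {'loss': 0.1} 1 1
```
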

class ablator.main.state.store.TrialState(value)

Bases: IntEnum

An enumeration of the possible states of a trial, including several more specific pruned states.

Attributes:
RUNNING : int

A trial that has been successfully scheduled to run.

COMPLETE : int

A successfully completed trial.

PRUNED : int

A trial that was pruned for various reasons.

FAIL : int

A trial that produced an unrecoverable error during execution.

WAITING : int

A trial that is waiting to be scheduled to run.

PRUNED_INVALID : int

A trial that was pruned during sampling because it was invalid.

PRUNED_DUPLICATE : int

A trial that was sampled but was already present.

PRUNED_POOR_PERFORMANCE : int

A trial that was pruned during execution for poor performance.

FAIL_RECOVERABLE : int

A trial that encountered a recoverable error during execution.

COMPLETE: int = 1
FAIL: int = 3
FAIL_RECOVERABLE: int = 8
PRUNED: int = 2
PRUNED_DUPLICATE: int = 6
PRUNED_INVALID: int = 5
PRUNED_POOR_PERFORMANCE: int = 7
RUNNING: int = 0
WAITING: int = 4
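
The members above are listed alphabetically. Restated in value order as a plain IntEnum (a sketch mirroring the documented names and values, not the library's source), which also shows that a stored integer maps back to its member:

```python
from enum import IntEnum


class TrialState(IntEnum):
    # Documented members, listed in value order.
    RUNNING = 0
    COMPLETE = 1
    PRUNED = 2
    FAIL = 3
    WAITING = 4
    PRUNED_INVALID = 5
    PRUNED_DUPLICATE = 6
    PRUNED_POOR_PERFORMANCE = 7
    FAIL_RECOVERABLE = 8


# An IntEnum round-trips through its integer value and compares equal to ints.
print(TrialState(8).name)  # FAIL_RECOVERABLE
print(TrialState.WAITING == 4)  # True

# Group the pruned variants by name prefix.
pruned = [s.name for s in TrialState if s.name.startswith("PRUNED")]
print(pruned)  # ['PRUNED', 'PRUNED_INVALID', 'PRUNED_DUPLICATE', 'PRUNED_POOR_PERFORMANCE']
```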

Module contents