Source code for ablator.main.state.store
import enum
from sqlalchemy import Integer, PickleType, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
    """Declarative base class that the ORM models of this module inherit from."""
class TrialState(enum.IntEnum):
    """
    An enumeration of possible states for a trial with more pruned states.

    Attributes
    ----------
    RUNNING : int
        A trial that has been successfully scheduled to run
    COMPLETE : int
        Successfully completed trial
    PRUNED : int
        Trial pruned because of various reasons
    FAIL : int
        Trial that produced an unrecoverable error during execution
    WAITING : int
        Trial that is waiting to be scheduled to run
    PRUNED_INVALID : int
        Trial that was pruned during sampling as it was invalid
    PRUNED_DUPLICATE : int
        Trial that was sampled but was already present
    PRUNED_POOR_PERFORMANCE : int
        Trial that was pruned during execution for poor performance
    FAIL_RECOVERABLE : int
        Trial that produced a recoverable error during execution
        (the recoverable counterpart of ``FAIL``)
    """

    # Members are intentionally left unannotated: annotating an enum member
    # (``RUNNING: int = 0``) is reported as an error by type checkers, which
    # infer the member type themselves. The integer values are unchanged.
    RUNNING = 0
    COMPLETE = 1
    PRUNED = 2
    FAIL = 3
    WAITING = 4
    PRUNED_INVALID = 5
    PRUNED_DUPLICATE = 6
    PRUNED_POOR_PERFORMANCE = 7
    FAIL_RECOVERABLE = 8
class Trial(Base):
    """
    Class to store data about a trial.

    Attributes
    ----------
    id: Mapped[int]
        The trial Id used for internal purposes (primary key)
    config_uid: Mapped[str]
        The configuration identifier associated with the trial's unique attributes
    metrics: Mapped[PickleType]
        The performance metrics dictionary associated as reported by the trial.
        Dict[str,float] where str is the metric name and float is the metric value.
    config_param: Mapped[PickleType]
        The configuration parameters for the specific trial including the defaults.
    aug_config_param: Mapped[PickleType]
        The augmenting configuration as picked by the config sampler.
        It is the values only different from the default config (excl. Derived properties)
    trial_num: Mapped[int]
        The trial_num corresponding to the internal HPO sampler, used to communicate with the sampler.
    state: Mapped[PickleType]
        The ``TrialState``; defaults to ``TrialState.WAITING``
    runtime_errors: Mapped[int]
        Total runtime errors that the trial encountered and are incremented
        every time the trial faces a recoverable error. Defaults to 0.
    """

    __tablename__ = "trial"

    # NOTE(review): the ``Mapped[PickleType]`` annotations below name the SQL
    # column type rather than the Python value type. They are left as-is here
    # because the explicit ``mapped_column(PickleType)`` argument is what sets
    # the column type; consider tightening them to the real Python types.
    id: Mapped[int] = mapped_column(primary_key=True)
    config_uid: Mapped[str] = mapped_column(String(30))
    metrics: Mapped[PickleType] = mapped_column(PickleType)
    config_param: Mapped[PickleType] = mapped_column(PickleType)
    aug_config_param: Mapped[PickleType] = mapped_column(PickleType)
    # Annotated with the Python type (``int``), matching ``id`` and
    # ``runtime_errors``; the SQL type still comes from ``mapped_column(Integer)``.
    trial_num: Mapped[int] = mapped_column(Integer)
    state: Mapped[PickleType] = mapped_column(PickleType, default=TrialState.WAITING)
    runtime_errors: Mapped[int] = mapped_column(Integer, default=0)

    # NOTE the following attributes are subject to be removed when optuna is decoupled.
    # they are for internal use ONLY
    _opt_distributions_kwargs: Mapped[PickleType] = mapped_column(
        PickleType, nullable=True
    )
    _opt_distributions_types: Mapped[PickleType] = mapped_column(
        PickleType, nullable=True
    )
    _opt_params: Mapped[PickleType] = mapped_column(PickleType, nullable=True)

    def __repr__(self) -> str:
        """Return a debug representation with the id, config uid and augmenting config."""
        # The last field is labelled with the attribute it actually prints
        # (it was previously mislabelled ``fullname``).
        return (
            f"Trial(id={self.id!r}, config_uid={self.config_uid!r}, "
            f"aug_config_param={self.aug_config_param!r})"
        )