Source code for ablator.main.hpo.base

from abc import ABC, abstractmethod
import typing as ty

from ablator.main.state.store import TrialState


class BaseSampler(ABC):
    """
    Abstract base class for hyperparameter samplers.

    Sampling follows a two-step protocol. ``eager_sample`` draws a trial and
    locks the sampler. ``unlock`` must then be called to either keep
    (``drop=False``) or discard (``drop=True``) that trial before
    ``eager_sample`` may be called again.

    Subclasses implement ``_eager_sample``, ``_drop``, ``update_trial`` and
    ``internal_repr``.
    """

    def __init__(self) -> None:
        super().__init__()
        # True while an eagerly sampled trial is waiting for `unlock`.
        self._lock = False

    @abstractmethod
    def _eager_sample(self) -> tuple[int, dict[str, ty.Any]]:
        """
        Sample a new trial from the search space.

        Called by ``eager_sample`` after the sampler has been locked. Please
        see the documentation of ``eager_sample``.

        Returns
        -------
        tuple[int, dict[str, ty.Any]]
            the trial id and the sampled configuration.

        Raises
        ------
        StopIteration
            Can be raised when there are no more trials to sample.
        """
        raise NotImplementedError

    @abstractmethod
    def update_trial(
        self, trial_id: int, metrics: dict[str, float] | None, state: TrialState
    ):
        """
        Update the trial state given the trial_id, the updated metrics, and
        the current trial state.

        Parameters
        ----------
        trial_id : int
            the trial_id which was returned when running ``eager_sample``
        metrics : dict[str, float] | None
            a metric dictionary corresponding to the updated metrics, or
            ``None``.
        state : TrialState
            the updated trial state
        """
        raise NotImplementedError

    @abstractmethod
    def internal_repr(self, trial_id: int) -> None | dict[str, ty.Any]:
        """
        Return the internal representation of the trial if one is maintained
        by the sampler.

        Parameters
        ----------
        trial_id : int
            the trial_id whose internal representation is retrieved.

        Returns
        -------
        None | dict[str, ty.Any]
            ``None`` when there is no internal representation maintained by
            the sampler. Otherwise a dictionary whose keys are the internal
            configuration names and whose values are the corresponding
            values.
        """
        raise NotImplementedError

    @abstractmethod
    def _drop(self) -> None:
        """
        Delete the eagerly sampled trial.

        Called by ``unlock`` when ``drop=True``. Please see the documentation
        of ``unlock`` for implementation details.
        """
        raise NotImplementedError

    def eager_sample(
        self,
    ) -> tuple[int, dict[str, ty.Any], None | dict[str, ty.Any]]:
        """
        Sample a trial and lock the sampler until ``unlock`` is called.

        A sampled trial can be erroneous. For this reason we eagerly sample
        and lock the sampler until the user can verify the sampled
        configuration and report the outcome through ``unlock``.

        Returns
        -------
        tuple[int, dict[str, ty.Any], None | dict[str, ty.Any]]
            the trial id, the sampled configuration from the search space,
            and the trial's internal representation as returned by
            ``internal_repr`` (``None`` when the sampler maintains none).

        Raises
        ------
        StopIteration
            Can raise an error if there are no more trials to sample.
        AssertionError
            If the sampler is still locked by a previous ``eager_sample``
            call that was not followed by ``unlock`` (only when assertions
            are enabled).

        Notes
        -----
        The lock is taken before ``_eager_sample`` runs. If sampling raises,
        the sampler therefore remains locked, and ``unlock`` must still be
        called before sampling again.
        """
        assert (
            not self._lock
        ), "Must call `unlock(drop=[True,False])` after `eager_sample`."
        self._lock = True
        trial_id, config = self._eager_sample()
        kwargs = self.internal_repr(trial_id)
        return trial_id, config, kwargs

    def unlock(self, drop: bool) -> None:
        """
        Report whether the eagerly sampled trial was valid or should be
        dropped, and release the lock taken by ``eager_sample``.

        Calling ``unlock(drop=False)`` while the sampler is not locked has no
        effect.

        Parameters
        ----------
        drop : bool
            whether to drop the trial. When ``True``, ``_drop`` is called
            after the lock is released.

        Raises
        ------
        AssertionError
            If ``drop`` is ``True`` but the sampler is not locked, i.e. there
            is no eagerly sampled trial to drop (only when assertions are
            enabled).
        """
        assert self._lock or not drop, (
            "Cannot drop a trial: the sampler is not locked. "
            "Call `eager_sample` before `unlock(drop=True)`."
        )
        self._lock = False
        if drop:
            self._drop()