ablator.main.hpo package

Submodules

ablator.main.hpo.base module

class ablator.main.hpo.base.BaseSampler

Bases: ABC

eager_sample() → tuple[int, dict[str, Any], None | dict[str, Any]]

A sampled trial can be erroneous. For this reason, the sampler samples eagerly and stays locked until the user has verified the sampled configuration and called unlock.

Returns:
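tuple[int, dict[str, ty.Any], None | dict[str, ty.Any]]

a tuple that contains the trial id, the sampled configuration from the search space, and the sampler's internal representation of the trial (None when the sampler maintains none; see internal_repr).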

Raises:
StopIteration

Raised when there are no more trials to sample.

abstract internal_repr(trial_id: int) → None | dict[str, Any]

Return the internal representation of the trial if one is maintained by the sampler.

Parameters:
trial_id : int

the trial_id whose internal representation is retrieved.

Returns:
None | dict[str, ty.Any]

None when the sampler maintains no internal representation. Otherwise, a dictionary mapping the internal configuration names to their corresponding values.

unlock(drop: bool)

unlock informs the sampler whether the eagerly sampled trial was valid or should be dropped.

Parameters:
drop : bool

whether to drop the trial

abstract update_trial(trial_id: int, metrics: dict[str, float] | None, state: TrialState)

Update the trial state given the trial_id, the updated metrics, and the current trial state.

Parameters:
trial_id : int

the trial_id that was returned by eager_sample

metrics : dict[str, float] | None

a dictionary mapping metric names to their updated values, or None.

state : TrialState

the updated trial state

ablator.main.hpo.grid module

class ablator.main.hpo.grid.GridSampler(search_space: dict[str, ablator.config.hpo.SearchSpace], configs: list[dict[str, Any]] | None = None, seed: int | None = None)

Bases: BaseSampler

internal_repr(trial_id)

This function is a no-op for grid sampling, as the grid sampler has no need to maintain an internal representation of trials.

update_trial(trial_id, metrics: dict[str, float] | None, state)

This function is a no-op for grid sampling, as the sampled configurations do not depend on the results of previous trials.

ablator.main.hpo.optuna module

class ablator.main.hpo.optuna.OptunaSampler(search_algo: SearchAlgo, search_space: dict[str, ablator.config.hpo.SearchSpace], optim_metrics: OrderedDict[str, Optim], trials: list['_state_store.Trial'] | None = None, seed: int | None = None)

OptunaSampler serves as an interface for Optuna-based samplers. WARNING: The class will be refactored in future versions and should not be used by library users; it is meant for internal use only.

Module contents