ablator.main.hpo package#

Submodules#

ablator.main.hpo.base module#

class ablator.main.hpo.base.BaseSampler[source]#

Bases: ABC

eager_sample() tuple[int, dict[str, Any], None | dict[str, Any]][source]#

A sampled trial can be erroneous. For this reason, the sampler samples eagerly and stays locked until the user has verified the sampled configuration.

Returns:
tuple[int, dict[str, ty.Any], None | dict[str, ty.Any]]

A tuple that contains the trial id, the sampled configuration from the search space and, when the sampler maintains one, the internal representation of the trial (otherwise None).

Raises:
StopIteration

Raised when there are no more trials to sample.

abstract internal_repr(trial_id: int) None | dict[str, Any][source]#

Return the internal representation of the trial if one is maintained by the sampler.

Parameters:
trial_id : int

The id of the trial whose internal representation is retrieved.

Returns:
None | dict[str, ty.Any]

None when the sampler maintains no internal representation. Otherwise, a dictionary that maps the internal configuration names to their corresponding values.

Raises:
NotImplementedError

If the method is not implemented by the subclasses.

unlock(drop: bool)[source]#

unlock informs the sampler whether the eagerly sampled trial was valid or should be dropped.

Parameters:
drop : bool

whether to drop the trial
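
The eager_sample and unlock handshake described above can be sketched with a small stand-in class. This is not ablator's implementation: ToyEagerSampler, its attributes, and the use of a plain threading.Lock are all assumptions made for illustration, and the toy never maintains an internal representation, so the third tuple element is always None.

```python
import threading
from typing import Any


class ToyEagerSampler:
    """Hypothetical stand-in that mimics the eager_sample/unlock protocol."""

    def __init__(self, configs: list):
        self._pending = list(configs)  # configurations not yet handed out
        self._accepted: dict = {}      # trial_id -> verified configuration
        self._next_id = 0
        self._held: Any = None         # (trial_id, config) awaiting a verdict
        self._lock = threading.Lock()

    def eager_sample(self) -> tuple:
        # Hold the sampler until the caller reports back through unlock().
        self._lock.acquire()
        if not self._pending:
            self._lock.release()
            raise StopIteration("no more trials to sample")
        config = self._pending.pop(0)
        self._held = (self._next_id, config)
        self._next_id += 1
        return self._held[0], config, None

    def unlock(self, drop: bool) -> None:
        trial_id, config = self._held
        if not drop:
            # Only verified configurations are recorded.
            self._accepted[trial_id] = config
        self._held = None
        self._lock.release()


sampler = ToyEagerSampler([{"lr": 0.1}, {"lr": -1.0}])
kept = []
while True:
    try:
        trial_id, config, _ = sampler.eager_sample()
    except StopIteration:
        break
    is_invalid = config["lr"] <= 0  # toy validity check on the sampled values
    sampler.unlock(drop=is_invalid)
    if not is_invalid:
        kept.append(trial_id)
```

In this sketch every call to eager_sample is paired with exactly one call to unlock, so the sampler is never left locked after a configuration has been checked.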

abstract update_trial(trial_id: int, metrics: OrderedDict[str, float] | None, state: TrialState)[source]#

Update the trial state given the trial_id, the updated metrics, and the current trial state.

Parameters:
trial_id : int

The trial_id that was returned by eager_sample.

metrics : OrderedDict[str, float] | None

a metric dictionary corresponding to the updated metrics.

state : TrialState

the updated trial state

Raises:
NotImplementedError

If the method is not implemented by the subclasses.
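
As a rough illustration of the update_trial call shape, the sketch below records the latest metrics and state per trial. ToyTrialState and ToyTracker are hypothetical names, not part of ablator, and keeping the previous metrics when metrics is None is an assumption of this sketch rather than documented behavior.

```python
from collections import OrderedDict
from enum import Enum


class ToyTrialState(Enum):
    """Hypothetical stand-in for ablator's TrialState."""

    RUNNING = "running"
    COMPLETE = "complete"


class ToyTracker:
    """Hypothetical object exposing an update_trial-like method."""

    def __init__(self):
        self.history = {}  # trial_id -> {"metrics": ..., "state": ...}

    def update_trial(self, trial_id, metrics, state):
        previous = self.history.get(trial_id, {"metrics": None})
        self.history[trial_id] = {
            # Assumption: a None update leaves the last known metrics in place.
            "metrics": metrics if metrics is not None else previous["metrics"],
            "state": state,
        }


tracker = ToyTracker()
tracker.update_trial(0, OrderedDict(val_loss=0.42), ToyTrialState.RUNNING)
tracker.update_trial(0, None, ToyTrialState.COMPLETE)
```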

ablator.main.hpo.grid module#

class ablator.main.hpo.grid.GridSampler(search_space: dict[str, ablator.config.hpo.SearchSpace], configs: list[dict[str, Any]] | None = None, seed: int | None = None)[source]#

Bases: BaseSampler

GridSampler expands the search space into a grid of evenly spaced values. For example, a search space of SearchSpace(value_range=[1,10], n_bins=10) is discretized into 10 values, [1,..,10]. If the search space is composed of integers, e.g. value_type='int', the values are rounded down via the default Python int() implementation and only the unique subset is kept. As a result, the discretized search space can be smaller than n_bins. For example, SearchSpace(value_range=[1,5], value_type='int', n_bins=1000) produces the search space {1,2,3,4,5}. In contrast, SearchSpace(value_range=[1,5], value_type='float', n_bins=1000) produces a search space of 1000 floats, [1., 1.004004, 1.00800801, ..., 4.98798799, 4.99199199, 4.995996, 5.].

Previous configurations can be supplied via the configs argument. If the configurations are not found in the discretized search_space (could be because of numerical stability errors or poor instantiation) they will be stored in memory. Any duplicate configurations will be removed from current sampling memory.
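
The discretization described above can be approximated in plain Python. This is a minimal sketch, not the GridSampler code: linspace and discretize are hypothetical helpers, and the evenly spaced values are assumed to include both endpoints of value_range.

```python
def linspace(start, stop, n_bins):
    """Return n_bins evenly spaced floats from start to stop, inclusive."""
    if n_bins == 1:
        return [float(start)]
    span = stop - start
    # Dividing last keeps the final value exactly equal to stop.
    return [start + span * i / (n_bins - 1) for i in range(n_bins)]


def discretize(value_range, value_type, n_bins):
    """Discretize a numeric range the way the description above suggests."""
    low, high = value_range
    values = linspace(low, high, n_bins)
    if value_type == "int":
        # int() truncates, and the set keeps only the unique integers.
        return sorted({int(v) for v in values})
    return values


ints = discretize([1, 5], "int", 1000)
floats = discretize([1, 5], "float", 1000)
ten = discretize([1, 10], "float", 10)
```

Here ints collapses to five values even though n_bins is 1000, while floats keeps all 1000 grid points.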

Parameters:
search_space : dict[str, SearchSpace]

A dictionary that maps each configuration name to the search space to sample from.

configs : list[dict[str, ty.Any]] | None

Previous configurations to resume the state from, by default None.

seed : int | None

A seed to use for the HPO sampler, by default None.

internal_repr(trial_id: int) None[source]#

This function is a no-op for grid sampling, as it does not need to maintain an internal representation of trials.

update_trial(trial_id: int, metrics: dict[str, float] | None, state)[source]#

This function is a no-op for grid sampling, as configurations are sampled entirely at random and do not depend on trial results.

ablator.main.hpo.optuna module#

class ablator.main.hpo.optuna.OptunaSampler(search_algo: SearchAlgo, search_space: dict[str, ablator.config.hpo.SearchSpace], optim_metrics: OrderedDict[str, Optim], trials: list['_state_store.Trial'] | None = None, seed: int | None = None)[source]#

OptunaSampler serves as an interface for Optuna-based samplers. WARNING: The class will be refactored in future versions and should not be used by library users. It is meant for internal use only.

Module contents#