Source code for ablator.main.hpo.grid

import copy
import logging
import typing as ty

import numpy as np


from ablator.config.hpo import SearchSpace
from ablator.main.hpo.base import BaseSampler


def _parse_search_space(space: SearchSpace) -> list:
    if space.sub_configuration is not None:
        return _expand_search_space(space.sub_configuration.arguments)
    if space.value_range is not None:
        low, high = space.parsed_value_range()
        if space.n_bins is None:
            raise ValueError(f"`n_bins` must be specified for {space}.")
        num = space.n_bins
        dtype = int if space.value_type == "int" else float
        return sorted(set(np.linspace(low, high, num, dtype=dtype)))
    if space.categorical_values is not None:
        return space.categorical_values
    if space.subspaces is not None:
        return [_ for _v in space.subspaces for _ in _parse_search_space(_v)]
    raise ValueError(f"Invalid SearchSpace: {space}")


def _expand_configs(
    configs: list[dict[str, str | int | float | dict]],
    value: dict[str, SearchSpace] | SearchSpace | ty.Any,
    key,
) -> list:
    _configs = []
    if isinstance(value, dict):
        expanded_space = _expand_search_space(value)
    elif isinstance(value, SearchSpace):
        expanded_space = _parse_search_space(value)
    else:
        expanded_space = [value]
    for _config in configs:
        for _v in expanded_space:
            _config[key] = _v
            _configs.append(copy.deepcopy(_config))
    return _configs


def _expand_search_space(
    search_space: dict[str, SearchSpace]
) -> list[dict[str, str | int | float | dict]]:
    configs: list[dict[str, str | int | float | dict]] = [{}]

    for k, v in search_space.items():
        try:
            configs = _expand_configs(configs, v, k)
        except ValueError as e:
            raise ValueError(f"Invalid search space for {k}. {str(e)}") from e
    return configs


[docs]class GridSampler(BaseSampler): """ GridSampler, expands the grid-space into evenly spaced intervals. For example, a search space over ``SearchSpace(value_range=[1,10], n_bins=10)`` will be discritized to 10 intervals [1,..,10]. If the search space is composed of integers, e.g. ``value_type='int'`` the search space will be rounded down via the default python `int()` implementation and only the unique subset will be considered. As a result the discritized search-space can be smaller than n_bins. For example: ``SearchSpace(value_range=[1,5], value_type='int', n_bins=1000)`` will produce a SearchSpace of ``{1,2,3,4,5}``. In contrast, ``SearchSpace(value_range=[1,5], value_type='float', n_bins=1000)`` will produce a SearchSpace of 1000 floats, ``[1. , 1.004004 , 1.00800801, ... , 4.98798799, 4.99199199, 4.995996 , 5.]``. Previous configurations can be supplied via the `configs` argument. If the configurations are not found in the discretized search_space (could be because of numerical stability errors or poor instantiation) they will be stored in memory. Any duplicate configurations will be removed from current sampling memory. Parameters ---------- search_space : dict[str, SearchSpace] A dictionary with keys the configuration name and the search space to sample from configs : list[dict[str, ty.Any]] | None Previous configurations to resume the state from, by default ``None``. seed : int | None A seed to use for the HPO sampler, by default ``None``. """ def __init__( self, search_space: dict[str, SearchSpace], configs: list[dict[str, ty.Any]] | None = None, seed: int | None = None, ) -> None: super().__init__() self.search_space = search_space self.configs = _expand_search_space(search_space) self.sampled_configs = configs if configs is not None else [] for c in self.sampled_configs: if c in self.configs: self.configs.remove(c) else: logging.warning( "Invalid sampled configuration provided to GridSampler. %s", c ) self._lock = False self._rng = np.random.default_rng(seed) # mypy error because of nested dictionary self._rng.shuffle(self.configs) # type: ignore[arg-type] @property def _idx(self): return len(self.sampled_configs) - 1 def _eager_sample(self): if len(self.configs) == 0: # reset self.configs = _expand_search_space(self.search_space) self._rng.shuffle(self.configs) cfg = self.configs.pop() self.sampled_configs.append(cfg) return self._idx, cfg def _drop(self): self.sampled_configs.pop()
[docs] def update_trial(self, trial_id: int, metrics: dict[str, float] | None, state): """ This function is a no-op for grid sampling as it is entirely random. """
[docs] def internal_repr(self, trial_id: int) -> None: """ This function is a no-op for grid sampling as it does not need a reason to maintain an internal representation of trials. """ return None