# Source code for ablator.main.hpo.optuna

# type: ignore
# pylint: skip-file
"""
TODO: the current implementation is meant to be temporary until there is a concrete replacement for
Optuna or a better integration. There are several problems with using Optuna:
1. Optuna does not support conditional search spaces.
    For example, if `config-01.a=1` was sampled from `[config-01, config-02]`, this is not taken
    into account when sampling `config.a`.

2. Optuna's internal TrialStates are limited: there are only a few states, and transitions
between them can raise errors. For example, we cannot report `None` metrics
for TrialState.COMPLETE, which is required when there are no optim_metrics
for a given sampling strategy.

3. Resuming the Sampler is problematic, as we have to match the
internal experiment state to that of the Optuna sampler.

4. Obscure implementation details. For example, the benefit of distinguishing between
`independent sampling` and `relative sampling` is unclear. Additional implementation nuances can be
seen by inspecting the code, e.g. if a parameter has already been sampled for a given trial,
that parameter is returned, which is error-prone as we might need to, for example, re-sample a
parameter in case of an error.

5. Removing trials in case of errors or issues is not possible. For example once a trial is sampled, it is stored
in the internal state. If the sampled configuration is invalid for whatever reason we do not want to store it.

Just to name a few...

"""
import collections
import typing as ty
import warnings
from collections import OrderedDict

import numpy as np
import optuna
from optuna.distributions import (
    BaseDistribution,
    CategoricalDistribution,
    FloatDistribution,
    IntDistribution,
)
from optuna.study._study_direction import StudyDirection
from optuna.trial import TrialState

from ablator.config.hpo import FieldType, SearchSpace
from ablator.config.mp import Optim, SearchAlgo
from ablator.main.state import store as _state_store
from ablator.main.hpo.base import BaseSampler


class _Trial:
    """
    Mock `optuna.Trial` object for the sake of using optuna.

    Stores the parameters sampled for one trial, the distributions they were
    sampled from, the trial's optuna ``TrialState`` and the objective values
    reported for ``optim_metrics`` (ordered like the ``optim_metrics`` keys).
    """

    def __init__(
        self,
        id_: int,
        study: "_Study",
        sampler: optuna.samplers.BaseSampler,
        optim_metrics: OrderedDict[str, str],
        resume_trial: ty.Optional["_state_store.Trial"] = None,
    ) -> None:
        """
        Parameters
        ----------
        id_ : int
            Identifier of the trial inside ``study``.
        study : _Study
            The mock study the trial belongs to; handed to the optuna sampler.
        sampler : optuna.samplers.BaseSampler
            Sampler used to infer and sample the relative search space.
        optim_metrics : OrderedDict[str, str]
            Metric name -> optimization direction (``"min"``/``"max"`` or ``Optim``).
            The key order defines the order of ``self.values``.
        resume_trial : _state_store.Trial | None
            When given, params, distributions, state and the last reported
            metrics are restored from it.

        Raises
        ------
        ValueError
            If a stored distribution type is not supported, or if the restored
            metrics do not match ``optim_metrics``.
        """
        self.state = TrialState.RUNNING

        self.values: list[float] | None = None
        self.id_ = id_
        self.params: dict[ty.Any, ty.Any] = {}
        self.distributions = {}
        self.optim_metrics = optim_metrics
        if resume_trial is not None:
            self.params = resume_trial._opt_params
            types = resume_trial._opt_distributions_types
            self.distributions = {
                k: self._make_distribution(types[k], kwargs)
                for k, kwargs in resume_trial._opt_distributions_kwargs.items()
            }
            metrics = None
            if len(resume_trial.metrics) > 0:
                # Only the most recent metrics report is relevant to the sampler.
                metrics = resume_trial.metrics[-1]
            self.update(metrics, resume_trial.state)
        self.relative_search_space = sampler.infer_relative_search_space(study, self)
        self.relative_params = sampler.sample_relative(
            study, self, self.relative_search_space
        )

    @staticmethod
    def _make_distribution(
        type_name: str, kwargs: dict[str, ty.Any]
    ) -> "BaseDistribution":
        """
        Rebuild a distribution from its stored class name and attribute kwargs.

        Only the distribution classes that ``OptunaSampler`` creates are
        accepted. The stored name is looked up explicitly instead of passing
        it to ``eval``, so the stored trial state cannot trigger arbitrary
        code execution.

        Raises
        ------
        ValueError
            If ``type_name`` is not a supported distribution class.
        """
        known_types = {
            "CategoricalDistribution": CategoricalDistribution,
            "FloatDistribution": FloatDistribution,
            "IntDistribution": IntDistribution,
        }
        try:
            dist_cls = known_types[type_name]
        except KeyError:
            raise ValueError(
                f"Unsupported distribution type `{type_name}` in stored trial."
            ) from None
        return dist_cls(**kwargs)

    @staticmethod
    def _worst_value(direction: ty.Any) -> float:
        """Return the least favourable objective value for ``direction``."""
        # ``+inf`` is only the worst value when minimizing; for a maximized
        # metric it would make a trial with a missing metric look like the best.
        if direction in {Optim.max, "max"}:
            return float("-inf")
        return float("inf")

    def update(
        self, metrics: OrderedDict[str, float] | None, state: "_state_store.TrialState"
    ):
        """
        Sync the optuna state and objective values from an ablator trial state.

        Ablator states other than COMPLETE, FAIL and RUNNING are ignored and the
        trial is left unchanged. A metric reported as ``None`` is replaced by
        the worst value for its direction (``+inf`` when minimizing, ``-inf``
        when maximizing) so the sampler never treats it as promising.

        Raises
        ------
        ValueError
            If the keys of ``metrics`` differ from the keys of ``optim_metrics``.
        """
        if state == _state_store.TrialState.COMPLETE:
            self.state = TrialState.COMPLETE
        elif state == _state_store.TrialState.FAIL:
            self.state = TrialState.FAIL
        elif state == _state_store.TrialState.RUNNING:
            self.state = TrialState.RUNNING
        else:
            return
        if metrics is None:
            self.values = None
            return
        metric_keys = set(metrics.keys())
        optim_keys = set(self.optim_metrics.keys())
        if metric_keys != optim_keys:
            raise ValueError(
                f"metric keys {metric_keys} do not match optim_keys {optim_keys}"
            )
        self.values = [
            metrics[k] if metrics[k] is not None else self._worst_value(direction)
            for k, direction in self.optim_metrics.items()
        ]

    def is_relative_param(self, name: str, distribution: BaseDistribution) -> bool:
        """Public alias of ``_is_relative_param`` (mirrors the optuna Trial API)."""
        return self._is_relative_param(name, distribution)

    def _is_relative_param(self, name: str, distribution: BaseDistribution) -> bool:
        """
        Whether ``name`` was produced by ``sample_relative`` and fits ``distribution``.

        Raises
        ------
        ValueError
            If ``name`` was relatively sampled but is not part of the relative
            search space.
        """
        if name not in self.relative_params:
            return False

        if name not in self.relative_search_space:
            raise ValueError(
                f"The parameter '{name}' was sampled by `sample_relative` method "
                "but it is not contained in the relative search space."
            )

        param_value = self.relative_params[name]
        param_value_in_internal_repr = distribution.to_internal_repr(param_value)
        return distribution._contains(param_value_in_internal_repr)


# type: ignore
class _Study:
    """
    Mock `optuna.Study` object for the sake of using optuna.

    Parameters
    ----------
    optim_metrics : collections.OrderedDict
        Metric name -> direction (``Optim.min``/``"min"`` or ``Optim.max``/``"max"``).
        One ``StudyDirection`` is derived per entry, in key order.
    sampler : optuna.samplers.BaseSampler
        The optuna sampler that queries this study.
    trials : list[_state_store.Trial] | None
        Previously stored trials to restore (e.g. when resuming).

    Raises
    ------
    ValueError
        If a direction in ``optim_metrics`` is not recognized.
    """

    def __init__(
        self,
        optim_metrics: collections.OrderedDict,
        sampler: optuna.samplers.BaseSampler,
        trials: list["_state_store.Trial"] | None = None,
    ) -> None:
        # All attributes are set before the stored trials are rebuilt, because
        # ``_Trial.__init__`` passes this study to the sampler's
        # ``infer_relative_search_space``/``sample_relative`` and the study must
        # already be usable at that point.
        self.sampler = sampler
        self.optim_metrics = optim_metrics
        self.directions: list[StudyDirection] = []
        for v in optim_metrics.values():
            if v in {Optim.min, "min"}:
                self.directions.append(StudyDirection.MINIMIZE)
            elif v in {Optim.max, "max"}:
                self.directions.append(StudyDirection.MAXIMIZE)
            else:
                raise ValueError(f"Unrecognized optim. Direction `{v}`")
        self.trials: list[_Trial] = []
        for trial in [] if trials is None else trials:
            self.trials.append(
                _Trial(
                    id_=trial.trial_num,
                    study=self,
                    sampler=sampler,
                    resume_trial=trial,
                    optim_metrics=optim_metrics,
                )
            )

    def get_trials(self, deepcopy=False, states=None):
        """
        Return the trials whose state is in ``states`` (all trials if ``None``).

        Raises
        ------
        NotImplementedError
            If ``deepcopy`` is requested; trials are always returned by reference.
        """
        if deepcopy:
            raise NotImplementedError("`_Study.get_trials` does not support deepcopy.")
        if states is None:
            return list(self.trials)
        return [t for t in self.trials if t.state in states]

    def _is_multi_objective(self):
        """Whether more than one metric is being optimized."""
        return len(self.directions) > 1

    def drop(self, trial_id):
        """
        Remove the trial with id ``trial_id``.

        Raises
        ------
        RuntimeError
            If no trial has that id.
        """
        for i, trial in enumerate(self.trials):
            if trial.id_ == trial_id:
                del self.trials[i]
                return
        raise RuntimeError(f"Trial {trial_id} was not found.")

    def make_trial(self):
        """
        Create, without registering, a new trial whose id is one above the
        largest existing id (0 for an empty study).
        """
        id_ = max((t.id_ for t in self.trials), default=-1) + 1
        return _Trial(
            id_=id_,
            study=self,
            sampler=self.sampler,
            optim_metrics=self.optim_metrics,
        )

    def get_trial(self, trial_id):
        """
        Return the trial with id ``trial_id``.

        Raises
        ------
        RuntimeError
            If no trial has that id.
        """
        for trial in self.trials:
            if trial.id_ == trial_id:
                return trial
        raise RuntimeError(f"Trial {trial_id} was not found.")

    def update(self, trial_id, metrics, state):
        """Forward ``metrics``/``state`` to ``_Trial.update`` of the given trial."""
        self.get_trial(trial_id).update(metrics, state)


class OptunaSampler(BaseSampler):
    """
    Interface to Optuna-based samplers.

    WARNING: The class will be refactored in future versions and should not be used by
    library users. The class is meant for internal use only.

    Parameters
    ----------
    search_algo : SearchAlgo
        ``SearchAlgo.tpe`` (``optuna.samplers.TPESampler`` with ``constant_liar=True``)
        or ``SearchAlgo.random`` (``optuna.samplers.RandomSampler``).
    search_space : dict[str, SearchSpace]
        The (possibly nested) search space configurations are sampled from.
    optim_metrics : collections.OrderedDict[str, Optim]
        Metric name -> optimization direction. Must not be empty.
    trials : list[_state_store.Trial] | None
        Previously stored trials used to restore the sampler's history.
    seed : int | None
        Seed forwarded to the optuna sampler.

    Raises
    ------
    ValueError
        If ``search_algo`` is not supported.
    """

    def __init__(
        self,
        search_algo: SearchAlgo,
        search_space: dict[str, SearchSpace],
        optim_metrics: collections.OrderedDict[str, Optim],
        trials: list["_state_store.Trial"] | None = None,
        seed: int | None = None,
    ):
        super().__init__()
        self.sampler: optuna.samplers.TPESampler | optuna.samplers.RandomSampler
        assert (
            len(optim_metrics) > 0
        ), "Need to specify 'optim_metrics' with `OptunaSampler`"
        self.optim_metrics = OrderedDict(optim_metrics)
        if search_algo == SearchAlgo.tpe:
            # Silence warnings emitted while constructing the TPE sampler
            # (e.g. for options optuna marks as experimental).
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self.sampler = optuna.samplers.TPESampler(constant_liar=True, seed=seed)
        elif search_algo == SearchAlgo.random:
            self.sampler = optuna.samplers.RandomSampler(seed=seed)
        else:
            raise ValueError(f"Unrecognized search algorithm: {search_algo}.")
        self._study = _Study(self.optim_metrics, sampler=self.sampler, trials=trials)
        self.search_space = search_space

    def _eager_sample(self):
        """
        Sample a new configuration and register it as a RUNNING trial.

        The trial is appended to the study only after sampling succeeds, so a
        failure while sampling leaves the study unchanged.

        Returns
        -------
        tuple[int, dict[str, ty.Any]]
            The new trial id and the sampled configuration.
        """
        trial = self._study.make_trial()
        config = self._sample_trial_params(trial, self.search_space)
        trial.state = TrialState.RUNNING
        self._study.trials.append(trial)
        return trial.id_, config

    def _drop(self):
        """Remove the most recently appended trial from the study."""
        self._study.trials.pop()

    def _suggest(self, trial: _Trial, name, dist):
        """
        Sample ``name`` from ``dist`` and record the value and distribution on ``trial``.

        A value already produced by relative sampling is reused; otherwise the
        optuna sampler samples it independently.
        """
        if trial.is_relative_param(name, dist):
            val = trial.relative_params[name]
        else:
            val = self.sampler.sample_independent(self._study, trial, name, dist)
        trial.params[name] = val
        trial.distributions[name] = dist
        return val

    def _suggest_int(
        self,
        trial: _Trial,
        name: str,
        value_range: tuple[int, int] | tuple[float, float],
        log: bool = False,
        n_bins: int | None = None,
    ):
        """
        Suggest an integer in ``value_range``.

        With ``n_bins`` the step is ``(high - low) // n_bins``, but never less than 1.
        """
        low, high = value_range
        step = 1 if n_bins is None else max((high - low) // n_bins, 1)
        dist = IntDistribution(low, high, log=log, step=step)
        return self._suggest(trial, name, dist)

    def _suggest_float(
        self,
        trial: _Trial,
        name: str,
        value_range: tuple[int, int] | tuple[float, float],
        log: bool = False,
        n_bins: int | None = None,
    ):
        """
        Suggest a float in ``value_range``.

        The value is continuous unless ``n_bins`` is given, in which case the
        step is ``(high - low) / n_bins``.
        """
        low, high = value_range
        step = None if n_bins is None else (high - low) / n_bins
        dist = FloatDistribution(low, high, log=log, step=step)
        return self._suggest(trial, name, dist)

    def _suggest_categorical(self, trial: _Trial, name: str, vals: list[str]):
        """Suggest one of ``vals``."""
        dist = CategoricalDistribution(choices=vals)
        return self._suggest(trial, name, dist)

    def _sample_trial_params(
        self,
        trial: _Trial,
        search_space: dict[str, SearchSpace | dict],
    ) -> dict[str, ty.Any]:
        """
        Recursively sample a configuration from ``search_space`` for ``trial``.

        Parameter names passed to optuna are dotted paths of the config keys.
        A chosen subspace appends ``_<idx>`` to the prefix.

        Raises
        ------
        ValueError
            If a ``SearchSpace`` defines none of the supported options.
        """
        parameter: dict[str, ty.Any] = {}

        def _sample_params(
            v,
            prefix: str = "",
        ):
            if isinstance(v, dict):
                return {
                    _k: _sample_params(_v, prefix=f"{prefix}.{_k}")
                    for _k, _v in v.items()
                }
            if not isinstance(v, SearchSpace):
                # Constant values are passed through unchanged.
                return v
            if v.value_range is not None and v.value_type == FieldType.discrete:
                return self._suggest_int(
                    trial, prefix, v.parsed_value_range(), v.log, v.n_bins
                )
            if v.value_range is not None and v.value_type == FieldType.continuous:
                return self._suggest_float(
                    trial, prefix, v.parsed_value_range(), v.log, v.n_bins
                )
            if v.categorical_values is not None:
                return self._suggest_categorical(trial, prefix, v.categorical_values)
            if v.subspaces is not None:
                # TODO make it non-random e.g. pick the best sub-configuration.
                # Can use a dummy categorical variable
                # NOTE: this draws from the global NumPy RNG, so the `seed` given
                # to the optuna sampler does not make this choice reproducible.
                idx = np.random.choice(len(v.subspaces))
                return _sample_params(
                    v.subspaces[idx],
                    prefix=f"{prefix}_{idx}",
                )
            if v.sub_configuration is not None:
                return {
                    _k: _sample_params(_v, prefix=f"{prefix}.{_k}")
                    for _k, _v in v.sub_configuration.arguments.items()
                }
            raise ValueError(f"Invalid SearchSpace {v}.")

        for k, v in search_space.items():
            parameter[k] = _sample_params(v, k)
        return parameter

    def update_trial(
        self,
        trial_id,
        metrics: dict[str, float] | None,
        state: "_state_store.TrialState",
    ):
        """
        Report ``metrics`` and the ablator ``state`` for ``trial_id`` to the study.

        ``state`` is an ablator ``_state_store.TrialState``; it is converted to
        optuna's ``TrialState`` by ``_Trial.update``.
        """
        self._study.update(trial_id, metrics, state)

    def internal_repr(self, trial_id):
        """
        Return the state needed to rebuild trial ``trial_id`` with ``_Trial``.

        The keys match the attributes ``_Trial`` reads from a stored trial when
        resuming. Distribution kwargs are copies of each distribution's
        attributes, so they do not alias the live distribution objects.
        """
        trial = self._study.get_trial(trial_id)
        distributions = trial.distributions
        return {
            "_opt_params": trial.params,
            "_opt_distributions_kwargs": {
                k: dict(v.__dict__) for k, v in distributions.items()
            },
            "_opt_distributions_types": {
                k: type(v).__name__ for k, v in distributions.items()
            },
        }