from collections import abc
import copy
import inspect
import logging
import operator
import typing as ty
from typing import Any, Union
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing_extensions import Self
from omegaconf import OmegaConf
from ablator.config.types import (
Annotation,
Derived,
Dict,
Enum,
List,
Literal,
Stateless,
Tuple,
Type,
parse_type_hint,
parse_value,
)
from ablator.config.utils import dict_hash, flatten_nested_dict, parse_repr_to_kwargs
[docs]def configclass(cls: type["ConfigBase"]) -> type["ConfigBase"]:
"""
Decorator for ``ConfigBase`` subclasses, adds the ``config_class`` attribute to the class.
Parameters
----------
cls : type["ConfigBase"]
The class to be decorated.
Returns
-------
type[ConfigBase]
The decorated class with the ``config_class`` attribute.
"""
assert issubclass(cls, ConfigBase), f"{cls.__name__} must inherit from ConfigBase"
setattr(cls, "config_class", cls)
return dataclass(cls, init=False, repr=False, kw_only=True, eq=False) # type: ignore[call-overload]
def _freeze_helper(obj):
def __setattr__(self, k, v):
if getattr(self, "_freeze", False):
raise RuntimeError(
f"Can not set attribute {k} on a class of a frozen configuration ``{type(self).__name__}``."
)
super(type(self), self).__setattr__(k, v)
try:
obj._freeze = True # pylint: disable=protected-access
type(obj).__setattr__ = __setattr__
except Exception: # pylint: disable=broad-exception-caught
# this is the case where the object does not have
# attribute setter function
pass
def _unfreeze_helper(obj):
if hasattr(obj, "_freeze"):
super(type(obj), obj).__setattr__("_freeze", False)
def _parse_reconstructor(val, ignore_stateless: bool, flatten: bool):
if isinstance(val, (int, float, bool, str, type(None))):
return val
if issubclass(type(val), ConfigBase):
return val.make_dict(
val.annotations, ignore_stateless=ignore_stateless, flatten=flatten
)
if issubclass(type(val), Enum):
return val.value
args, kwargs = parse_repr_to_kwargs(val)
if len(args) == 0:
return kwargs
if len(kwargs) == 0:
return args
return args, kwargs
[docs]class Missing:
"""
This type is defined only for raising an error
"""
[docs]@dataclass(repr=False)
class ConfigBase:
# NOTE: this allows for non-defined arguments to be created. It is very bug-prone and will be disabled.
"""
This class is the building block for all configuration objects within ablator. It serves as the base class for
configurations such as ``ModelConfig``, ``TrainConfig``, ``OptimizerConfig``, and more. Together with
``@configclass``, it allows for the creation of config classes of customized attributes without the need to
define a constructor. ``ConfigBase`` and ``@configclass`` take care of the initialization and parsing of the
attributes. The example section below shows this in more detail.
In summary, to customize configurations for specific needs, you can create your own configuration class by
inheriting it from ``ConfigBase``. It's essential to annotate it with ``@configclass``. In the tutorial
`Search space for different types of optimizers and scheduler
<./notebooks/Searchspace-for-diff-optimizers.ipynb>`_, a custom optimizer config class is created to enable
ablation study on various optimizers and schedulers. You can refer to this tutorial for a realistic example of
how to create your custom configuration class.
.. note::
One key takeaway is that when initializing a config object, you can look into the list of attributes defined
in the config class to see what arguments you can pass.
Parameters
----------
*args : Any
This argument is just for disabling passing by positional arguments.
debug : bool, optional
Whether to load the configuration in debug mode and ignore discrepancies/errors, by default ``False``.
**kwargs : Any
Keyword arguments. Possible arguments are from the annotations of the configuration class. You can look into the
Examples section for more details.
Attributes
----------
config_class : Type
The class of the configuration object.
Raises
------
ValueError
If positional arguments are provided or there are missing required values.
KeyError
If unexpected arguments are provided.
RuntimeError
If the class is not decorated with ``@configclass``.
.. note::
All config classes must be decorated with ``@configclass``.
Examples
--------
>>> @configclass
>>> class MyCustomConfig(ConfigBase):
... attr1: int = 1
... attr2: Tuple[str, int, str]
>>> my_config = MyCustomConfig(attr1=4, attr2=("hello", 1, "world")) # Pass by named arguments
>>> kwargs = {"attr1": 4, "attr2": ("hello", 1, "world")} # Pass by keyword arguments
>>> my_config = MyCustomConfig(**kwargs)
Note that since we defined ``MyCustomConfig`` as a config class with two annotated attributes ``attr1``
and ``attr2`` (without a constructor, which is automatically handled by ``ConfigBase`` and
``@configclass``), when creating the config object, you can directly pass ``attr1`` and ``attr2``. You
can also pass these arguments as keyword arguments.
"""
config_class = type(None)
def __init__(self, *args: Any, debug: bool = False, **kwargs: Any):
self._debug: bool
self._freeze: bool
self._class_name: str
self.__setattr__internal("_debug", debug)
self.__setattr__internal("_freeze", False)
self.__setattr__internal("_class_name", type(self).__name__)
missing_vals = self._validate_inputs(*args, debug=debug, **kwargs)
assert len(missing_vals) == 0 or debug
for k in self.annotations:
if k in kwargs:
v = kwargs[k]
del kwargs[k]
else:
v = getattr(self, k, None)
if k in missing_vals:
logging.warning(
"Loading %s in `debug` mode. Setting missing required value %s to `None`.",
self._class_name,
k,
)
self.__setattr__internal(k, None)
else:
try:
setattr(self, k, v)
except Exception as e: # pylint: disable=broad-exception-caught
if not debug:
raise e
logging.warning(
"Loading %s in `debug` mode. Unable to parse `%s` value %s. Setting to `None`.",
self._class_name,
k,
v,
)
self.__setattr__internal(k, None)
if len(kwargs) > 0 and not debug:
unspected_args = ", ".join(kwargs.keys())
raise KeyError(f"Unexpected arguments: `{unspected_args}`")
if len(kwargs) > 0:
unspected_args = ", ".join(kwargs.keys())
logging.warning(
"Loading %s in `debug` mode. Ignoring unexpected arguments: `%s`",
self._class_name,
unspected_args,
)
def _validate_inputs(self, *args, debug: bool, **kwargs) -> list[str]:
added_variables = {
item[0]
for item in inspect.getmembers(type(self))
if not inspect.isfunction(item[1]) and not item[0].startswith("_")
}
base_variables = {
item[0]
for item in inspect.getmembers(ConfigBase)
if not inspect.isfunction(item[1])
}
non_annotated_variables = (
added_variables - base_variables - set(self.annotations.keys())
)
assert (
len(non_annotated_variables) == 0
), f"All variables must be annotated. {non_annotated_variables}"
if len(args) > 0:
raise ValueError(
f"{self._class_name} does not support positional arguments."
)
if not isinstance(self, self.config_class): # type: ignore[arg-type]
raise RuntimeError(
f"You must decorate your Config class '{self._class_name}' with ablator.configclass."
)
missing_vals = self._validate_missing(**kwargs)
if len(missing_vals) != 0 and not debug:
raise ValueError(f"Missing required values {missing_vals}.")
return missing_vals
def _validate_missing(self, **kwargs) -> list[str]:
missing_vals = []
for k, annotation in self.annotations.items():
if not annotation.optional and annotation.state not in [Derived]:
# make sure non-optional and derived values are not empty or
# without a default assignment
if not (
(k in kwargs and kwargs[k] is not None)
or getattr(self, k, None) is not None
):
missing_vals.append(k)
return missing_vals
def __setattr__internal(self, k, v):
super().__setattr__(k, v)
def __setattr__(self, k, v):
if self._freeze:
raise RuntimeError(
f"Can not set attribute {k} on frozen configuration ``{type(self).__name__}``."
)
annotation = self.annotations[k]
v = parse_value(v, annotation, k, self._debug)
self.__setattr__internal(k, v)
def __eq__(self, other):
if isinstance(other, self.__class__):
return len(self.diff(other)) == 0
return False
def __repr__(self) -> str:
"""
Return the string representation of the configuration object.
Returns
-------
str
The string representation of the configuration object.
"""
return (
self._class_name
+ "("
+ ", ".join(
[
f"{k}='{v}'" if isinstance(v, str) else f"{k}={v.__repr__()}"
for k, v in self.to_dict().items()
]
)
+ ")"
)
[docs] def keys(self) -> abc.KeysView[str]:
"""
Get the keys of the configuration dictionary.
Returns
-------
abc.KeysView[str]
The keys of the configuration dictionary.
"""
return self.to_dict().keys()
[docs] @classmethod
def load(cls, path: Union[Path, str], debug: bool = False) -> Self:
"""
Load a configuration object from a file.
Parameters
----------
path : Union[Path, str]
The path to the configuration file.
debug : bool, optional
Whether to load the configuration in debug mode, and ignore discrepancies/errors,
by default ``False``.
Returns
-------
Self
The loaded configuration object.
"""
# TODO[iordanis] remove OmegaConf dependency
kwargs: dict = OmegaConf.to_object( # type: ignore[assignment]
OmegaConf.create(Path(path).read_text(encoding="utf-8"))
)
return cls(**kwargs, debug=debug)
@property
def annotations(self) -> dict[str, Annotation]:
"""
Get the parsed annotations of the configuration object.
Returns
-------
dict[str, Annotation]
A dictionary of parsed annotations.
"""
annotations = {}
if hasattr(self, "__annotations__"):
annotation_types = dict(self.__annotations__)
# pylint: disable=no-member
# Without the if statement it will over-write new configurations
# e.x.
# class ReConfig(RunConfig):
# train_config: SomeTrainConfig = SomeTrainConfig()
# model_config: SomeModelConfig = SomeModelConfig()
# TODO test-me
dataclass_types = {
k: v.type
for k, v in self.__dataclass_fields__.items()
if k not in annotation_types
}
annotation_types.update(dataclass_types)
annotations = {
field_name: parse_type_hint(type(self), annotation)
for field_name, annotation in annotation_types.items()
}
return annotations
[docs] def get_val_with_dot_path(self, dot_path: str) -> Any:
"""
Get the value of a configuration object attribute using dot notation.
Parameters
----------
dot_path : str
The dot notation path to the attribute.
Returns
-------
Any
The value of the attribute.
"""
return operator.attrgetter(dot_path)(self)
[docs] def get_type_with_dot_path(self, dot_path: str) -> Type:
"""
Get the type of a configuration object attribute using dot notation.
Parameters
----------
dot_path : str
The dot notation path to the attribute.
Returns
-------
Type
The type of the attribute.
"""
val = self.get_val_with_dot_path(dot_path)
return type(val)
[docs] def get_annot_type_with_dot_path(self, dot_path: str) -> Type:
"""
Get the type of a configuration object annotation using dot notation.
Parameters
----------
dot_path : str
The dot notation path to the annotation.
Returns
-------
Type
The type of the annotation.
"""
*base_path, element = dot_path.split(".")
annot_dot_path = ".".join(base_path + ["annotations"])
annot: dict[str, Annotation] = self.get_val_with_dot_path(annot_dot_path)
return annot[element].variable_type
# pylint: disable=too-complex
[docs] def make_dict(
self,
annotations: dict[str, Annotation],
ignore_stateless: bool = False,
flatten: bool = False,
) -> dict:
"""
Create a dictionary representation of the configuration object.
Parameters
----------
annotations : dict[str, Annotation]
A dictionary of annotations.
ignore_stateless : bool
Whether to ignore stateless values, by default ``False``.
flatten : bool
Whether to flatten nested dictionaries, by default ``False``.
Returns
-------
dict
The dictionary representation of the configuration object.
Raises
------
NotImplementedError
If the type of annot.collection is not supported.
"""
return_dict = {}
parse_reconstructor = partial(
_parse_reconstructor, ignore_stateless=ignore_stateless, flatten=flatten
)
for field_name, annot in annotations.items():
if ignore_stateless and (annot.state in {Stateless, Derived}):
continue
_val = getattr(self, field_name)
if annot.collection in [None, Literal] or _val is None:
val = _val
elif annot.collection == List:
val = [parse_reconstructor(_lval) for _lval in _val]
elif annot.collection == Tuple:
val = tuple(parse_reconstructor(_lval) for _lval in _val)
elif annot.collection in [Dict]:
val = {k: parse_reconstructor(_dval) for k, _dval in _val.items()}
elif issubclass(type(_val), ConfigBase):
val = _val.make_dict(
_val.annotations, ignore_stateless=ignore_stateless, flatten=flatten
)
elif annot.collection == Type:
if annot.optional and _val is None:
val = None
else:
val = parse_reconstructor(_val)
elif issubclass(type(_val), Enum):
val = _val.value
else:
raise NotImplementedError
return_dict[field_name] = val
if flatten:
return_dict = flatten_nested_dict(return_dict)
return return_dict
[docs] def write(self, path: Union[Path, str]):
"""
Write the configuration object to a file.
Parameters
----------
path : Union[Path, str]
The path to the file.
"""
Path(path).write_text(self.to_yaml(), encoding="utf-8")
[docs] def diff_str(
self, config: "ConfigBase", ignore_stateless: bool = False
) -> list[str]:
"""
Get the differences between the current configuration object and another configuration object as strings.
Parameters
----------
config : ConfigBase
The configuration object to compare.
ignore_stateless : bool
Whether to ignore stateless values, by default ``False``.
Returns
-------
list[str]
The list of differences as strings.
"""
diffs = self.diff(config, ignore_stateless=ignore_stateless)
str_diffs = []
for p, (l_t, l_v), (r_t, r_v) in diffs:
_diff = f"{p}:({l_t.__name__}){l_v}->({r_t.__name__}){r_v}"
str_diffs.append(_diff)
return str_diffs
[docs] def diff(
self, config: "ConfigBase", ignore_stateless: bool = False
) -> list[tuple[str, tuple[type, Any], tuple[type, Any]]]:
"""
Get the differences between the current configuration object and another configuration object.
Parameters
----------
config : ConfigBase
The configuration object to compare.
ignore_stateless : bool
Whether to ignore stateless values, by default ``False``
Returns
-------
list[tuple[str, tuple[type, Any], tuple[type, Any]]]
The list of differences as tuples.
Examples
--------
Let's say we have two configuration objects ``config1`` and ``config2`` with the following attributes:
>>> config1:
learning_rate: 0.01
optimizer: 'Adam'
num_layers: 3
>>> config2:
learning_rate: 0.02
optimizer: 'SGD'
num_layers: 3
The diff between these two configurations would look like:
>>> config1.diff(config2)
[('learning_rate', (float, 0.01), (float, 0.02)), ('optimizer', (str, 'Adam'), (str, 'SGD'))]
In this example, the learning_rate and optimizer values are different between the two configuration objects.
"""
left_config = copy.deepcopy(self)
right_config = copy.deepcopy(config)
left_dict = left_config.make_dict(
left_config.annotations, ignore_stateless=ignore_stateless, flatten=True
)
right_dict = right_config.make_dict(
right_config.annotations, ignore_stateless=ignore_stateless, flatten=True
)
left_keys = set(left_dict.keys())
right_keys = set(right_dict.keys())
diffs: list[tuple[str, tuple[type, ty.Any], tuple[type, ty.Any]]] = []
for k in left_keys.union(right_keys):
if k not in left_dict:
right_v = right_dict[k]
right_type = type(right_v)
diffs.append((k, (Missing, None), (right_type, right_v)))
elif k not in right_dict:
left_v = left_dict[k]
left_type = type(left_v)
diffs.append((k, (left_type, left_v), (Missing, None)))
elif left_dict[k] != right_dict[k] or not isinstance(
left_dict[k], type(right_dict[k])
):
right_v = right_dict[k]
left_v = left_dict[k]
left_type = type(left_v)
right_type = type(right_v)
diffs.append((k, (left_type, left_v), (right_type, right_v)))
return diffs
[docs] def to_dict(self, ignore_stateless: bool = False) -> dict:
"""
Convert the configuration object to a dictionary.
Parameters
----------
ignore_stateless : bool
Whether to ignore stateless values, by default ``False``.
Returns
-------
dict
The dictionary representation of the configuration object.
"""
return self.make_dict(self.annotations, ignore_stateless=ignore_stateless)
[docs] def to_yaml(self) -> str:
"""
Convert the configuration object to YAML format.
Returns
-------
str
The YAML representation of the configuration object.
"""
# TODO: investigate https://github.com/crdoconnor/strictyaml as an alternative to OmegaConf
conf = OmegaConf.create(self.to_dict())
return OmegaConf.to_yaml(conf)
[docs] def to_dot_path(self, ignore_stateless: bool = False) -> str:
"""
Convert the configuration object to a dictionary with dot notation paths as keys.
Parameters
----------
ignore_stateless : bool
Whether to ignore stateless values, by default ``False``.
Returns
-------
str
The YAML representation of the configuration object in dot notation paths.
"""
_flat_dict = self.make_dict(
self.annotations, ignore_stateless=ignore_stateless, flatten=True
)
return OmegaConf.to_yaml(OmegaConf.create(_flat_dict))
@property
def uid(self) -> str:
"""
Get the unique identifier for the configuration object.
Returns
-------
str
The unique identifier for the configuration object.
"""
return dict_hash(self.make_dict(self.annotations, ignore_stateless=True))[:5]
[docs] def assert_unambigious(self):
"""
Assert that the configuration object is unambiguous and has all the required values.
Raises
------
AssertionError
If the configuration object is ambiguous or missing required values.
"""
for k, annot in self.annotations.items():
if not annot.optional:
assert (
getattr(self, k) is not None
), f"Ambiguous configuration `{self._class_name}`. Must provide value for {k}"
self._apply_lambda_recursively("assert_unambigious")
[docs] def freeze(self):
self.__setattr__internal("_freeze", True)
self._apply_lambda_recursively("freeze")
for k, annot in self.annotations.items():
if (
isinstance(annot.variable_type, type)
and not issubclass(annot.variable_type, ConfigBase)
and getattr(self, k) is not None
and hasattr(getattr(self, k), "__setattr__")
):
if annot.collection in [List, Tuple]:
for _lval in getattr(self, k):
_freeze_helper(_lval)
elif annot.collection in [Dict]:
for _lval in getattr(self, k).values():
_freeze_helper(_lval)
else:
_freeze_helper(getattr(self, k))
def _unfreeze(self):
self.__setattr__internal("_freeze", False)
self._apply_lambda_recursively("_unfreeze")
for k, annot in self.annotations.items():
if (
isinstance(annot.variable_type, type)
and not issubclass(annot.variable_type, ConfigBase)
and getattr(self, k) is not None
):
if annot.collection in [List, Tuple]:
for _lval in getattr(self, k):
_unfreeze_helper(_lval)
elif annot.collection in [Dict]:
for _lval in getattr(self, k).values():
_unfreeze_helper(_lval)
else:
_unfreeze_helper(getattr(self, k))
def _apply_lambda_recursively(self, lam: str, *args):
for k, annot in self.annotations.items():
if (
isinstance(annot.variable_type, type)
and issubclass(annot.variable_type, ConfigBase)
and getattr(self, k) is not None
):
if annot.collection in [List, Tuple]:
for _lval in getattr(self, k):
getattr(_lval, lam)(*args)
elif annot.collection in [Dict]:
for _lval in getattr(self, k).values():
getattr(_lval, lam)(*args)
else:
getattr(getattr(self, k), lam)(*args)