import logging
import typing as ty
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from PIL import Image
from ablator.analysis.main import Analysis
from ablator.analysis.plot import Plot
from ablator.analysis.plot.cat_plot import ViolinPlot
from ablator.analysis.plot.num_plot import LinearPlot
from ablator.config.proto import Optim
logger = logging.getLogger(__name__)
[docs]class PlotAnalysis(Analysis):
"""
Class for plotting experiment results. You can use this class and ``Results`` class to visualize the
relationship between the result metrics and any hyperparameter you run ablation study on. This
valuable insight offers an intuitive understanding of how these parameters may influence your
model's performance.
Plots supported are linear plots for numerical data and violin plots for categorical data.
Parameters
----------
results : pd.DataFrame | Results
The result dataframe.
categorical_attributes : list[str] | None
The list of all the categorical hyperparameter names
numerical_attributes : list[str] | None
The list of all the numerical hyperparameter names
optim_metrics : dict[str, Optim] | None
A dictionary mapping metric names to optimization directions.
save_dir : str | None
The directory to save analysis results to.
cache : bool
Whether to cache results.
Examples
--------
- Data frame to be used:
>>> df = pd.DataFrame({'val_accuracy': np.random.uniform(0.8,0.9,10),
... 'train_config.optimizer_config.arguments.lr': np.random.uniform(0.001, 0.1,10),
... "index": range(10),
... "path": range(10)})
- Creating dictionaries that map the configuration parameters [categorical + numerical] to custom labels for plots:
>>> numerical_name_remap = {
... "train_config.optimizer_config.arguments.lr": "Learning Rate",
... }
... categorical_name_remap = {}
... attribute_name_remap = {**categorical_name_remap, **numerical_name_remap}
- Initalize the ``PlotAnalysis`` and plot the figures:
>>> analysis = PlotAnalysis(
... df,
... save_dir="./plots",
... cache=True,
... optim_metrics={"val_accuracy": Optim.max},
... numerical_attributes=list(numerical_name_remap.keys()),
... categorical_attributes=list(categorical_name_remap.keys()),
... )
>>> analysis.make_figures(
... metric_name_remap={
... "val_accuracy": "Validation Accuracy",
... },
... attribute_name_remap= attribute_name_remap
... )
The directory ``"plots"`` contains all the plots of the HPO experiments
"""
@classmethod
def _write_images(
cls,
fig_map: dict[str, ty.Union[Axes, Figure, Image.Image]],
path: Path,
file_format: ty.Literal["png", "pdf", "jpg"] = "png",
):
"""
Write images to a directory based on figure types.
Parameters
----------
fig_map : dict[str, ty.Union[Axes, Figure, Image.Image]]
A dictionary mapping names to ``matplotlib`` objects.
path : Path
Path to save the images to.
file_format : ty.Literal["png", "pdf", "jpg"]
the file format to save the images as.
Examples
--------
>>> fig_map = {"figure1": plt.subplots()[0]}
>>> path = Path("output_dir")
>>> PlotAnalysis._write_images(fig_map, path, "png")
"""
path.mkdir(exist_ok=True, parents=True)
for name, fig in fig_map.items():
img_path = path.joinpath(f"{name}.{file_format}")
if isinstance(fig, Axes):
fig.figure.savefig(img_path)
elif isinstance(fig, Figure):
fig.savefig(img_path)
elif isinstance(fig, Image.Image):
fig.save(img_path)
@classmethod
def _make_metric_plots(
cls,
path: Path | None,
plot_cls: type[Plot],
metrics: pd.DataFrame,
results: pd.DataFrame,
metric_map: dict[str, Optim],
append: bool = False,
ax: Axes | None = None,
metric_name_remap: dict[str, str] | None = None,
attribute_name_remap: dict[str, str] | None = None,
**kwargs,
) -> dict:
"""
Create the attributes vs. metrics plots.
Parameters
----------
path : Path | None
A ``pathlib.Path`` object representing the directory to write images to.
plot_cls : type[Plot]
A subclass of ``Plot`` representing the type of plot to make.
metrics : pd.DataFrame
A pandas ``DataFrame`` containing metric values.
results : pd.DataFrame
A pandas ``DataFrame`` containing attribute values.
metric_map : dict[str, Optim]
A dictionary mapping metric names to optimization functions.
append : bool
A boolean indicating whether to append plots to an existing axes object.
ax : Axes | None
A ``matplotlib.axes.Axes`` object representing the axis to plot on.
metric_name_remap : dict[str, str] | None
An optional dictionary mapping metric names to new metric names.
attribute_name_remap : dict[str, str] | None
An optional dictionary mapping attribute names to new attribute names.
**kwargs
Additional keyword arguments to pass to the plot method.
Returns
-------
dict
A dictionary of metric name to its axes_map
Examples
--------
>>> metrics = pd.DataFrame({"metric1": [1, 2, 3], "metric2": [4, 5, 6]})
>>> results = pd.DataFrame({"attr1": [7, 8, 9], "attr2": [10, 11, 12]})
>>> metric_map = {"metric1": Optim.max, "metric2": Optim.min}
>>> PlotAnalysis._make_metric_plots(None, LinearPlot, metrics, results, metric_map, False, None, None, None)
"""
axes = {}
(results, metrics, metric_map) = cls._remap_results(
results, metrics, metric_map, metric_name_remap, attribute_name_remap
)
inv_metric_name_remap = None
if metric_name_remap is not None:
inv_metric_name_remap = {v: k for k, v in metric_name_remap.items()}
inv_attribute_name_remap = None
if attribute_name_remap is not None:
inv_attribute_name_remap = {v: k for k, v in attribute_name_remap.items()}
for metric_name in metrics.columns:
metric_values = metrics[metric_name]
metric_obj_fn = metric_map[metric_name]
axes_map = cls._make_plot(
plot_cls=plot_cls,
metric_values=metric_values,
metric_obj_fn=metric_obj_fn,
results=results,
append=append,
ax=ax,
inv_attribute_name_map=inv_attribute_name_remap,
**kwargs,
)
if inv_metric_name_remap is not None:
original_metric_name = inv_metric_name_remap[metric_name]
else:
original_metric_name = metric_name
if path is not None:
p = path.joinpath(original_metric_name)
cls._write_images(axes_map, p)
for axe in axes_map.values():
plt.close(axe.figure)
axes[metric_name] = axes_map
return axes
@classmethod
def _make_plot(
cls,
metric_values: pd.Series,
results: pd.DataFrame,
plot_cls: type[Plot],
metric_obj_fn: Optim,
append=False,
ax: Axes | None = None,
inv_attribute_name_map: dict[str, str] | None = None,
**kwargs,
) -> dict[str, Axes]:
axes_map = {}
for attribute_name in results.columns:
attribute_values = results[attribute_name]
_, axes = plot_cls(
metric=metric_values,
attributes=attribute_values,
metric_obj_fn=metric_obj_fn,
y_axis=metric_values.name,
x_axis=attribute_name,
ax=ax,
).make(**kwargs)
if inv_attribute_name_map is not None:
original_attribute_name = inv_attribute_name_map[attribute_name]
else:
original_attribute_name = attribute_name
if append:
ax = axes
else:
axes_map[original_attribute_name] = axes
plt.close()
if append:
axes_map["combined_attribute"] = axes
plt.close()
return axes_map
[docs] def make_violinplot(
self,
attribute_names: list[str],
metrics: list[str],
save_dir: ty.Union[Path, str],
**plt_kwargs,
):
"""
Make violin plots for the given attribute names (data type: discrete) v.s. the metrics
and save the plots to the `save_dir` directory.
Parameters
----------
attribute_names : list[str]
list of attributes to plot against the metrics.
metrics : list[str]
list of metrics to plot against the given attributes.
save_dir : ty.Union[Path, str]
directory to save results.
**plt_kwargs
Additional keyword arguments to pass to the plot.
"""
save_path = Path(save_dir).joinpath("violinplot")
metric_map = {k: v for k, v in self.optim_metrics.items() if k in metrics}
self._make_metric_plots(
path=save_path,
plot_cls=ViolinPlot,
metrics=self.results[metrics],
results=self.results[attribute_names],
metric_map=metric_map,
**plt_kwargs,
)
[docs] def make_linearplot(
self,
attribute_names: list[str],
metrics: list[str],
save_dir: ty.Union[Path, str],
**plt_kwargs,
) -> dict:
"""
To make linear plots for the given attribute names (data type: numerical) v.s. the metrics
and save the plots to the `save_dir` directory.
Parameters
----------
attribute_names : list[str]
list of attributes to plot against the metrics.
metrics : list[str]
list of metrics to plot against the given attributes.
save_dir : ty.Union[Path, str]
directory to save results.
**plt_kwargs
Additional keyword arguments to pass to the plot method.
Returns
-------
dict
"""
save_path = Path(save_dir).joinpath("linearplot")
metric_map = {k: v for k, v in self.optim_metrics.items() if k in metrics}
return self._make_metric_plots(
path=save_path,
plot_cls=LinearPlot,
metrics=self.results[metrics],
results=self.results[attribute_names],
metric_map=metric_map,
**plt_kwargs,
)
# flake8: noqa: DOC102