Skip to content

Commit d26cc3c

Browse files
committed
[refactor] Gather params in Dataclass to look smarter
1 parent 226441b commit d26cc3c

File tree

2 files changed

+165
-239
lines changed

2 files changed

+165
-239
lines changed

autoPyTorch/api/base_task.py

Lines changed: 17 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from smac.stats.stats import Stats
3232
from smac.tae import StatusType
3333

34-
from autoPyTorch.api.run_history_visualizer import RunHistoryVisualizer
34+
from autoPyTorch.api.run_history_visualizer import ColorLabelSettings, PlotSettingParams, RunHistoryVisualizer
3535
from autoPyTorch.automl_common.common.utils.backend import Backend, create
3636
from autoPyTorch.constants import (
3737
REGRESSION_TASKS,
@@ -1384,34 +1384,9 @@ def _print_debug_info_to_log(self) -> None:
13841384
def plot_perf_over_time(
13851385
self,
13861386
metric_name: str,
1387-
include_single_train: bool = True,
1388-
include_single_opt: bool = True,
1389-
include_single_test: bool = True,
1390-
include_ensemble_train: bool = True,
1391-
include_ensemble_test: bool = True,
1392-
color_single_train: str = 'red',
1393-
color_single_opt: str = 'blue',
1394-
color_single_test: str = 'green',
1395-
color_ensemble_train: str = 'brown',
1396-
color_ensemble_test: str = 'purple',
1397-
label_single_train: Optional[str] = None,
1398-
label_single_opt: Optional[str] = None,
1399-
label_single_test: Optional[str] = None,
1400-
label_ensemble_train: Optional[str] = None,
1401-
label_ensemble_test: Optional[str] = None,
14021387
ax: Optional[plt.Axes] = None,
1403-
n_points: int = 20,
1404-
xlabel: Optional[str] = None,
1405-
ylabel: Optional[str] = None,
1406-
xscale: str = 'log',
1407-
yscale: str = 'linear',
1408-
title: Optional[str] = None,
1409-
xlim: Optional[Tuple[float, float]] = None,
1410-
ylim: Optional[Tuple[float, float]] = None,
1411-
figsize: Optional[Tuple[int, int]] = None,
1412-
legend: bool = True,
1413-
legend_loc: Optional[str] = 'best',
1414-
show: bool = False,
1388+
plot_setting_params: PlotSettingParams = PlotSettingParams(),
1389+
color_label_settings: ColorLabelSettings = ColorLabelSettings(),
14151390
*args: Any,
14161391
**kwargs: Any
14171392
) -> None:
@@ -1426,77 +1401,28 @@ def plot_perf_over_time(
14261401
The names are available in
14271402
* autoPyTorch.metrics.CLASSIFICATION_METRICS
14281403
* autoPyTorch.metrics.REGRESSION_METRICS
1429-
include_single_(train, opt, test) (bool):
1430-
Whether to include single train/opt/test performance
1431-
to the plot.
1432-
include_ensemble_(train, test) (bool):
1433-
Whether to include ensemble train/test performance
1434-
to the plot.
1435-
color_single_(train, opt, test) (str):
1436-
What color to use for single train/opt/test performance.
1437-
color_ensemble_(train, test) (str):
1438-
What color to use for ensemble train/opt/test performance.
1439-
label_single_(train, opt, test) (bool):
1440-
What label in the legend to use for single train/opt/test performance.
1441-
label_ensemble_(train, test) (bool):
1442-
What label in the legend to use for ensemble train/opt/test performance.
14431404
ax (Optional[plt.Axes]):
14441405
axis to plot (subplots of matplotlib).
14451406
If None, it will be created automatically.
1446-
n_points (int):
1447-
The number of points to plot.
1448-
labels (Dict[str, str]):
1449-
The name of each plot.
1450-
xlabel (Optional[str]):
1451-
The label in the x axis.
1452-
ylabel (Optional[str]):
1453-
The label in the y axis.
1454-
xscale (str):
1455-
The scale of x axis.
1456-
yscale (str):
1457-
The scale of y axis.
1458-
xscale (Tuple[float, float]):
1459-
The range of x axis.
1460-
yscale (Tuple[float, float]):
1461-
The range of y axis.
1462-
title (Optional[str]):
1463-
The title of the subfigure.
1464-
figsize (Optional[Tuple[int, int]]):
1465-
The figure size.
1466-
legend (bool):
1467-
Whether to have legend in the figure.
1468-
legend_loc (str):
1469-
The location of the legend.
1470-
show (bool):
1471-
Whether to show the plot.
1407+
plot_setting_params (PlotSettingParams):
1408+
Parameters for the plot.
1409+
color_label_settings (ColorLabelSettings):
1410+
The settings of a pair of color and label for each plot.
14721411
args, kwargs (Any):
14731412
Arguments for the ax.plot.
14741413
"""
14751414

1476-
colors = {
1477-
f'single::train::{metric_name}' if include_single_train else '': color_single_train,
1478-
f'single::opt::{metric_name}' if include_single_opt else '': color_single_opt,
1479-
f'single::test::{metric_name}' if include_single_test else '': color_single_test,
1480-
f'ensemble::train::{metric_name}' if include_ensemble_train else '': color_ensemble_train,
1481-
f'ensemble::test::{metric_name}' if include_ensemble_test else '': color_ensemble_test,
1482-
}
1483-
colors.pop('', None) # Remove if the include_xxx is False
1484-
labels = {
1485-
f'single::train::{metric_name}' if include_single_train else '': label_single_train,
1486-
f'single::opt::{metric_name}' if include_single_opt else '': label_single_opt,
1487-
f'single::test::{metric_name}' if include_single_test else '': label_single_test,
1488-
f'ensemble::train::{metric_name}' if include_ensemble_train else '': label_ensemble_train,
1489-
f'ensemble::test::{metric_name}' if include_ensemble_test else '': label_ensemble_test,
1490-
}
1491-
labels.pop('', None) # Remove if the include_xxx is False
1492-
1493-
perf_metric_names = list(colors.keys())
1415+
colors, labels = {}, {}
1416+
for key, color_label in vars(color_label_settings).items():
1417+
if color_label is None:
1418+
continue
1419+
1420+
new_key = '::'.join(key.split('_'))
1421+
colors[new_key], labels[new_key] = color_label
14941422

14951423
self._visualizer.plot_perf_over_time(
1496-
perf_metric_names=perf_metric_names,
1497-
run_history=self.run_history,
1424+
metric_name=metric_name, ax=ax, colors=colors, labels=labels,
1425+
plot_setting_params=plot_setting_params, run_history=self.run_history,
14981426
ensemble_performance_history=self.ensemble_performance_history,
1499-
colors=colors, xlabel=xlabel, ylabel=ylabel, xscale=xscale, yscale=yscale,
1500-
ax=ax, n_points=n_points, labels=labels, xlim=xlim, ylim=ylim, title=title,
1501-
figsize=figsize, legend=legend, legend_loc=legend_loc, show=show, *args, **kwargs
1427+
*args, **kwargs
15021428
)

0 commit comments

Comments
 (0)