31
31
from smac .stats .stats import Stats
32
32
from smac .tae import StatusType
33
33
34
- from autoPyTorch .api .run_history_visualizer import RunHistoryVisualizer
34
+ from autoPyTorch .api .run_history_visualizer import ColorLabelSettings , PlotSettingParams , RunHistoryVisualizer
35
35
from autoPyTorch .automl_common .common .utils .backend import Backend , create
36
36
from autoPyTorch .constants import (
37
37
REGRESSION_TASKS ,
@@ -1384,34 +1384,9 @@ def _print_debug_info_to_log(self) -> None:
1384
1384
def plot_perf_over_time (
1385
1385
self ,
1386
1386
metric_name : str ,
1387
- include_single_train : bool = True ,
1388
- include_single_opt : bool = True ,
1389
- include_single_test : bool = True ,
1390
- include_ensemble_train : bool = True ,
1391
- include_ensemble_test : bool = True ,
1392
- color_single_train : str = 'red' ,
1393
- color_single_opt : str = 'blue' ,
1394
- color_single_test : str = 'green' ,
1395
- color_ensemble_train : str = 'brown' ,
1396
- color_ensemble_test : str = 'purple' ,
1397
- label_single_train : Optional [str ] = None ,
1398
- label_single_opt : Optional [str ] = None ,
1399
- label_single_test : Optional [str ] = None ,
1400
- label_ensemble_train : Optional [str ] = None ,
1401
- label_ensemble_test : Optional [str ] = None ,
1402
1387
ax : Optional [plt .Axes ] = None ,
1403
- n_points : int = 20 ,
1404
- xlabel : Optional [str ] = None ,
1405
- ylabel : Optional [str ] = None ,
1406
- xscale : str = 'log' ,
1407
- yscale : str = 'linear' ,
1408
- title : Optional [str ] = None ,
1409
- xlim : Optional [Tuple [float , float ]] = None ,
1410
- ylim : Optional [Tuple [float , float ]] = None ,
1411
- figsize : Optional [Tuple [int , int ]] = None ,
1412
- legend : bool = True ,
1413
- legend_loc : Optional [str ] = 'best' ,
1414
- show : bool = False ,
1388
+ plot_setting_params : PlotSettingParams = PlotSettingParams (),
1389
+ color_label_settings : ColorLabelSettings = ColorLabelSettings (),
1415
1390
* args : Any ,
1416
1391
** kwargs : Any
1417
1392
) -> None :
@@ -1426,77 +1401,28 @@ def plot_perf_over_time(
1426
1401
The names are available in
1427
1402
* autoPyTorch.metrics.CLASSIFICATION_METRICS
1428
1403
* autoPyTorch.metrics.REGRESSION_METRICS
1429
- include_single_(train, opt, test) (bool):
1430
- Whether to include single train/opt/test performance
1431
- to the plot.
1432
- include_ensemble_(train, test) (bool):
1433
- Whether to include ensemble train/test performance
1434
- to the plot.
1435
- color_single_(train, opt, test) (str):
1436
- What color to use for single train/opt/test performance.
1437
- color_ensemble_(train, test) (str):
1438
- What color to use for ensemble train/opt/test performance.
1439
- label_single_(train, opt, test) (bool):
1440
- What label in the legend to use for single train/opt/test performance.
1441
- label_ensemble_(train, test) (bool):
1442
- What label in the legend to use for ensemble train/opt/test performance.
1443
1404
ax (Optional[plt.Axes]):
1444
1405
axis to plot (subplots of matplotlib).
1445
1406
If None, it will be created automatically.
1446
- n_points (int):
1447
- The number of points to plot.
1448
- labels (Dict[str, str]):
1449
- The name of each plot.
1450
- xlabel (Optional[str]):
1451
- The label in the x axis.
1452
- ylabel (Optional[str]):
1453
- The label in the y axis.
1454
- xscale (str):
1455
- The scale of x axis.
1456
- yscale (str):
1457
- The scale of y axis.
1458
- xscale (Tuple[float, float]):
1459
- The range of x axis.
1460
- yscale (Tuple[float, float]):
1461
- The range of y axis.
1462
- title (Optional[str]):
1463
- The title of the subfigure.
1464
- figsize (Optional[Tuple[int, int]]):
1465
- The figure size.
1466
- legend (bool):
1467
- Whether to have legend in the figure.
1468
- legend_loc (str):
1469
- The location of the legend.
1470
- show (bool):
1471
- Whether to show the plot.
1407
+ plot_setting_params (PlotSettingParams):
1408
+ Parameters for the plot.
1409
+ color_label_settings (ColorLabelSettings):
1410
+ The settings of a pair of color and label for each plot.
1472
1411
args, kwargs (Any):
1473
1412
Arguments for the ax.plot.
1474
1413
"""
1475
1414
1476
- colors = {
1477
- f'single::train::{ metric_name } ' if include_single_train else '' : color_single_train ,
1478
- f'single::opt::{ metric_name } ' if include_single_opt else '' : color_single_opt ,
1479
- f'single::test::{ metric_name } ' if include_single_test else '' : color_single_test ,
1480
- f'ensemble::train::{ metric_name } ' if include_ensemble_train else '' : color_ensemble_train ,
1481
- f'ensemble::test::{ metric_name } ' if include_ensemble_test else '' : color_ensemble_test ,
1482
- }
1483
- colors .pop ('' , None ) # Remove if the include_xxx is False
1484
- labels = {
1485
- f'single::train::{ metric_name } ' if include_single_train else '' : label_single_train ,
1486
- f'single::opt::{ metric_name } ' if include_single_opt else '' : label_single_opt ,
1487
- f'single::test::{ metric_name } ' if include_single_test else '' : label_single_test ,
1488
- f'ensemble::train::{ metric_name } ' if include_ensemble_train else '' : label_ensemble_train ,
1489
- f'ensemble::test::{ metric_name } ' if include_ensemble_test else '' : label_ensemble_test ,
1490
- }
1491
- labels .pop ('' , None ) # Remove if the include_xxx is False
1492
-
1493
- perf_metric_names = list (colors .keys ())
1415
+ colors , labels = {}, {}
1416
+ for key , color_label in vars (color_label_settings ).items ():
1417
+ if color_label is None :
1418
+ continue
1419
+
1420
+ new_key = '::' .join (key .split ('_' ))
1421
+ colors [new_key ], labels [new_key ] = color_label
1494
1422
1495
1423
self ._visualizer .plot_perf_over_time (
1496
- perf_metric_names = perf_metric_names ,
1497
- run_history = self .run_history ,
1424
+ metric_name = metric_name , ax = ax , colors = colors , labels = labels ,
1425
+ plot_setting_params = plot_setting_params , run_history = self .run_history ,
1498
1426
ensemble_performance_history = self .ensemble_performance_history ,
1499
- colors = colors , xlabel = xlabel , ylabel = ylabel , xscale = xscale , yscale = yscale ,
1500
- ax = ax , n_points = n_points , labels = labels , xlim = xlim , ylim = ylim , title = title ,
1501
- figsize = figsize , legend = legend , legend_loc = legend_loc , show = show , * args , ** kwargs
1427
+ * args , ** kwargs
1502
1428
)
0 commit comments