
Commit 6992609

nabenabe0928 authored and ravinkohli committed
[feat] Add an object that realizes the perf over time viz (#331)
* [feat] Add an object that realizes the perf-over-time viz
* [fix] Modify TODOs and add comments to avoid complications
* [refactor] [feat] Format visualizer API and integrate this feature into BaseTask
* [refactor] Separate a shared raise-error process into a function
* [refactor] Gather params in a dataclass for readability
* [refactor] Merge the extraction from history into the result manager. Since this feature was added in a previous PR, we now rely on it to extract the history. To handle the order-by-start-time issue, I added sorting by end time.
* [feat] Merge the viz in the latest version
* [fix] Replace NaN with the worst value so that every point can always be handled as a number
* [fix] Fix mypy issues
* [test] Add a test for get_start_time
* [test] Add a test for ordering by end time
* [test] Add tests for ensemble results
* [test] Add tests for merging ensemble results and run history
* [test] Add tests for the case where ensemble_results is None
* [fix] Switch from datetime to timestamp in tests so they pass universally. Since the mapping from timestamp to datetime varies across machines, the tests failed in the previous version; they now use fixed timestamps instead.
* [fix] Rename status_msg to status_type because it does not need to be a str
* [fix] Change the name for homogeneity
* [fix] Fix based on the file-name change
* [test] Add tests for set_plot_args
* [test] Add tests for plot_perf_over_time in BaseTask
* [refactor] Replace redundant lines with pytest parametrization
* [test] Add tests for _get_perf_and_time
* [fix] Remove the viz attribute based on Ravin's comment
* [fix] Fix docstrings based on Ravin's comments
* [refactor] Move the color/label settings extraction into the dataclass. Since this process made the BaseTask method redundant (as Ravin pointed out), I made it a method of the dataclass so that this information is easy to fetch. Note that since the color and label information always depends on the optimization results, we must always pass the metric results to ensure we only get related keys.
* [test] Add tests for the color/label dict extraction
* [test] Add tests checking whether plt.show is called
* [refactor] Address Ravin's comments and add a TODO for the refactoring
* [refactor] Change KeyError in EnsembleResults to empty results. Since it is inconvenient not to be able to instantiate EnsembleResults when there are no histories, it can now be instantiated even when the results are empty; in that case we get empty arrays, which matches developer intuition.
* [refactor] Prohibit external updates to make objects more robust
* [fix] Remove the member variable _opt_scores since it is confusing. Since opt_scores came from cost in run_history while metric_dict came from additional_info, it was unclear where to look for what. With it removed, we always refer to additional_info when fetching information, and metrics are always available as raw values. Although the change is large, the functionality is unchanged and it is now easier to add further functionality.
* [example] Add an example of how to plot performance over time
* [fix] Fix unexpected train loss when using cross-validation
* [fix] Remove __main__ from the example based on Ravin's comment
* [fix] Move results_xxx from the API to utils
* [enhance] Change the plot-over-time example to save the figure. Since plt.show() does not work in some environments, the example now runs for everyone.
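The "sort by end time" and "NaN → worst value" fixes described in the commit message can be sketched as follows. This is a minimal illustration only: `RunRecord` and `merge_histories` are hypothetical stand-ins, not autoPyTorch's actual result classes.

```python
from dataclasses import dataclass


# Hypothetical record of one optimization run; autoPyTorch stores this
# information in its run history, not in this class.
@dataclass
class RunRecord:
    start_time: float
    end_time: float
    score: float  # may be NaN when a run crashed


def merge_histories(runs, worst_val):
    """Order runs by end time and make every score a plottable number.

    Sorting by start time is ambiguous when runs overlap, hence the
    commit's switch to end-time ordering; NaN scores are replaced by the
    metric's worst value so every point can be handled numerically.
    """
    merged = sorted(runs, key=lambda r: r.end_time)
    return [
        worst_val if r.score != r.score else r.score  # NaN != NaN
        for r in merged
    ]
```

Usage: with one crashed run that finished first, the crashed run's NaN becomes the worst value and precedes the successful run's score in the merged sequence.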
1 parent: 54ee98e

10 files changed: +1903 −563 lines

autoPyTorch/api/base_task.py

Lines changed: 58 additions & 1 deletion
@@ -21,6 +21,8 @@
 import joblib
 
+import matplotlib.pyplot as plt
+
 import numpy as np
 
 import pandas as pd
@@ -29,7 +31,7 @@
 from smac.stats.stats import Stats
 from smac.tae import StatusType
 
-from autoPyTorch.api.results_manager import ResultsManager, SearchResults
+from autoPyTorch import metrics
 from autoPyTorch.automl_common.common.utils.backend import Backend, create
 from autoPyTorch.constants import (
     REGRESSION_TASKS,
@@ -58,6 +60,8 @@
 )
 from autoPyTorch.utils.parallel import preload_modules
 from autoPyTorch.utils.pipeline import get_configuration_space, get_dataset_requirements
+from autoPyTorch.utils.results_manager import MetricResults, ResultsManager, SearchResults
+from autoPyTorch.utils.results_visualizer import ColorLabelSettings, PlotSettingParams, ResultsVisualizer
 from autoPyTorch.utils.single_thread_client import SingleThreadedClient
 from autoPyTorch.utils.stopwatch import StopWatch
@@ -1479,3 +1483,56 @@ def sprint_statistics(self) -> str:
             scoring_functions=self._scoring_functions,
             metric=self._metric
         )
+
+    def plot_perf_over_time(
+        self,
+        metric_name: str,
+        ax: Optional[plt.Axes] = None,
+        plot_setting_params: PlotSettingParams = PlotSettingParams(),
+        color_label_settings: ColorLabelSettings = ColorLabelSettings(),
+        *args: Any,
+        **kwargs: Any
+    ) -> None:
+        """
+        Visualize the performance over time using matplotlib.
+        The plot-related arguments are based on matplotlib.
+        Please refer to the matplotlib documentation for more details.
+
+        Args:
+            metric_name (str):
+                The name of the metric to visualize.
+                The names are available in
+                    * autoPyTorch.metrics.CLASSIFICATION_METRICS
+                    * autoPyTorch.metrics.REGRESSION_METRICS
+            ax (Optional[plt.Axes]):
+                The axis to plot on (a matplotlib subplot).
+                If None, it will be created automatically.
+            plot_setting_params (PlotSettingParams):
+                Parameters for the plot.
+            color_label_settings (ColorLabelSettings):
+                The settings of a pair of color and label for each plot.
+            args, kwargs (Any):
+                Arguments for ax.plot.
+        """
+
+        if not hasattr(metrics, metric_name):
+            raise ValueError(
+                f'metric_name must be in {list(metrics.CLASSIFICATION_METRICS.keys())} '
+                f'or {list(metrics.REGRESSION_METRICS.keys())}, but got {metric_name}'
+            )
+        if len(self.ensemble_performance_history) == 0:
+            raise RuntimeError('Visualization is available only after ensembles are evaluated.')
+
+        results = MetricResults(
+            metric=getattr(metrics, metric_name),
+            run_history=self.run_history,
+            ensemble_performance_history=self.ensemble_performance_history
+        )
+
+        colors, labels = color_label_settings.extract_dicts(results)
+
+        ResultsVisualizer().plot_perf_over_time(  # type: ignore
+            results=results, plot_setting_params=plot_setting_params,
+            colors=colors, labels=labels, ax=ax,
+            *args, **kwargs
+        )
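The guard at the top of `plot_perf_over_time` rejects unknown metric names by checking attribute presence on the `metrics` module. A self-contained sketch of that pattern follows; the `metrics` stub and `resolve_metric` helper are hypothetical stand-ins so the snippet runs without autoPyTorch installed, and the stub's contents do not reflect the library's real metric lists.

```python
import types

# Minimal stub standing in for autoPyTorch.metrics: the real module
# exposes metric objects as attributes plus dicts of valid names.
metrics = types.SimpleNamespace(
    accuracy="accuracy_scorer",
    CLASSIFICATION_METRICS={"accuracy": None},
    REGRESSION_METRICS={"r2": None},
)


def resolve_metric(metric_name: str):
    """Mirror the validation in plot_perf_over_time.

    Unknown names fail fast with the list of valid choices in the
    error message; known names resolve via getattr, as in the diff.
    """
    if not hasattr(metrics, metric_name):
        raise ValueError(
            f"metric_name must be in {list(metrics.CLASSIFICATION_METRICS)} "
            f"or {list(metrics.REGRESSION_METRICS)}, but got {metric_name}"
        )
    return getattr(metrics, metric_name)
```

Checking `hasattr` before `getattr` keeps the error message user-facing (listing valid metric names) instead of surfacing a raw `AttributeError` from the module lookup.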
