|
21 | 21 |
|
22 | 22 | import joblib
|
23 | 23 |
|
| 24 | +import matplotlib.pyplot as plt |
| 25 | + |
24 | 26 | import numpy as np
|
25 | 27 |
|
26 | 28 | import pandas as pd
|
|
29 | 31 | from smac.stats.stats import Stats
|
30 | 32 | from smac.tae import StatusType
|
31 | 33 |
|
32 |
| -from autoPyTorch.api.results_manager import ResultsManager, SearchResults |
| 34 | +from autoPyTorch import metrics |
33 | 35 | from autoPyTorch.automl_common.common.utils.backend import Backend, create
|
34 | 36 | from autoPyTorch.constants import (
|
35 | 37 | REGRESSION_TASKS,
|
|
58 | 60 | )
|
59 | 61 | from autoPyTorch.utils.parallel import preload_modules
|
60 | 62 | from autoPyTorch.utils.pipeline import get_configuration_space, get_dataset_requirements
|
| 63 | +from autoPyTorch.utils.results_manager import MetricResults, ResultsManager, SearchResults |
| 64 | +from autoPyTorch.utils.results_visualizer import ColorLabelSettings, PlotSettingParams, ResultsVisualizer |
61 | 65 | from autoPyTorch.utils.single_thread_client import SingleThreadedClient
|
62 | 66 | from autoPyTorch.utils.stopwatch import StopWatch
|
63 | 67 |
|
@@ -1479,3 +1483,56 @@ def sprint_statistics(self) -> str:
|
1479 | 1483 | scoring_functions=self._scoring_functions,
|
1480 | 1484 | metric=self._metric
|
1481 | 1485 | )
|
| 1486 | + |
| 1487 | + def plot_perf_over_time( |
| 1488 | + self, |
| 1489 | + metric_name: str, |
| 1490 | + ax: Optional[plt.Axes] = None, |
| 1491 | + plot_setting_params: PlotSettingParams = PlotSettingParams(), |
| 1492 | + color_label_settings: ColorLabelSettings = ColorLabelSettings(), |
| 1493 | + *args: Any, |
| 1494 | + **kwargs: Any |
| 1495 | + ) -> None: |
| 1496 | + """ |
| 1497 | + Visualize the performance over time using matplotlib. |
| 1498 | + The plot related arguments are based on matplotlib. |
| 1499 | + Please refer to the matplotlib documentation for more details. |
| 1500 | +
|
| 1501 | + Args: |
| 1502 | + metric_name (str): |
| 1503 | + The name of metric to visualize. |
| 1504 | + The names are available in |
| 1505 | + * autoPyTorch.metrics.CLASSIFICATION_METRICS |
| 1506 | + * autoPyTorch.metrics.REGRESSION_METRICS |
| 1507 | + ax (Optional[plt.Axes]): |
| 1508 | + axis to plot (subplots of matplotlib). |
| 1509 | + If None, it will be created automatically. |
| 1510 | + plot_setting_params (PlotSettingParams): |
| 1511 | + Parameters for the plot. |
| 1512 | + color_label_settings (ColorLabelSettings): |
| 1513 | + The settings of a pair of color and label for each plot. |
| 1514 | + args, kwargs (Any): |
| 1515 | + Arguments for the ax.plot. |
| 1516 | + """ |
| 1517 | + |
| 1518 | + if not hasattr(metrics, metric_name): |
| 1519 | + raise ValueError( |
| 1520 | + f'metric_name must be in {list(metrics.CLASSIFICATION_METRICS.keys())} ' |
| 1521 | + f'or {list(metrics.REGRESSION_METRICS.keys())}, but got {metric_name}' |
| 1522 | + ) |
| 1523 | + if len(self.ensemble_performance_history) == 0: |
| 1524 | + raise RuntimeError('Visualization is available only after ensembles are evaluated.') |
| 1525 | + |
| 1526 | + results = MetricResults( |
| 1527 | + metric=getattr(metrics, metric_name), |
| 1528 | + run_history=self.run_history, |
| 1529 | + ensemble_performance_history=self.ensemble_performance_history |
| 1530 | + ) |
| 1531 | + |
| 1532 | + colors, labels = color_label_settings.extract_dicts(results) |
| 1533 | + |
| 1534 | + ResultsVisualizer().plot_perf_over_time( # type: ignore |
| 1535 | + results=results, plot_setting_params=plot_setting_params, |
| 1536 | + colors=colors, labels=labels, ax=ax, |
| 1537 | + *args, **kwargs |
| 1538 | + ) |
0 commit comments