Skip to content

Commit a8fb556

Browse files
committed
Addressed comments, better documentation and dict for runhistory
1 parent 1c2332e commit a8fb556

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

autoPyTorch/api/base_task.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def __init__(
173173
self._dataset_requirements: Optional[List[FitRequirement]] = None
174174
self._metric: Optional[autoPyTorchMetric] = None
175175
self._logger: Optional[PicklableClientLogger] = None
176-
self.run_history: Optional[RunHistory] = None
176+
self.run_history: Dict = {}
177177
self.trajectory: Optional[List] = None
178178
self.dataset_name: Optional[str] = None
179179
self.cv_models_: Dict = {}
@@ -688,6 +688,10 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
688688
"Please consider increasing the run time to further improve performance.")
689689
break
690690

691+
self._logger.debug("Run history traditional: {}".format(run_history))
692+
# add run history of traditional to api run history
693+
self.run_history.update(run_history.data)
694+
run_history.save_json(os.path.join(self._backend.internals_directory, 'traditional_run_history.json'))
691695
return num_run
692696

693697
def _search(
@@ -958,8 +962,9 @@ def _search(
958962
search_space_updates=self.search_space_updates
959963
)
960964
try:
961-
self.run_history, self.trajectory, budget_type = \
965+
run_history, self.trajectory, budget_type = \
962966
_proc_smac.run_smbo()
967+
self.run_history.update(run_history.data)
963968
trajectory_filename = os.path.join(
964969
self._backend.get_smac_output_directory_for_run(self.seed),
965970
'trajectory.json')

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201
8686
return False
8787

8888
def get_additional_run_info(self) -> Dict[str, Any]: # pylint: disable=R0201
89+
"""
90+
Can be used to return additional info for the run.
91+
Returns:
92+
Dict[str, Any]:
93+
Currently contains
94+
1. pipeline_configuration: the configuration of the pipeline, i.e, the traditional model used
95+
2. trainer_configuration: the parameters for the traditional model used.
96+
Can be found in autoPyTorch/pipeline/components/setup/traditional_ml/classifier_configs
97+
"""
8998
return {'pipeline_configuration': self.configuration,
9099
'trainer_configuration': self.pipeline.named_steps['model_trainer'].choice.model.get_config()}
91100

0 commit comments

Comments
 (0)