
Commit 1e77132

Additional metrics during train (#194)
* Added additional metrics to fit dictionary
* Added in test also
1 parent 688f5cf commit 1e77132

File tree

autoPyTorch/evaluation/abstract_evaluator.py
test/test_api/test_api.py

2 files changed: +18 −3 lines

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 10 additions, 0 deletions

@@ -378,6 +378,16 @@ def __init__(self, backend: Backend,
             'backend': self.backend,
             'logger_port': logger_port,
         })
+
+        # Update fit dictionary with metrics passed to the evaluator
+        metrics_dict: Dict[str, List[str]] = {'additional_metrics': []}
+        metrics_dict['additional_metrics'].append(self.metric.name)
+        if all_supported_metrics:
+            assert self.additional_metrics is not None
+            for metric in self.additional_metrics:
+                metrics_dict['additional_metrics'].append(metric.name)
+        self.fit_dictionary.update(metrics_dict)
+
         assert self.pipeline_class is not None, "Could not infer pipeline class"
         pipeline_config = pipeline_config if pipeline_config is not None \
             else self.pipeline_class.get_default_pipeline_options()
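In effect, the evaluator now records the optimisation metric (and, when all_supported_metrics is set, every additional supported metric) under an 'additional_metrics' key in the fit dictionary, where later pipeline steps such as the trainer can read it and report those metrics per epoch. The standalone sketch below only illustrates that flow; build_metric_entry and the toy fit_dictionary contents are hypothetical names for illustration, not autoPyTorch's actual API.

# Illustrative sketch only: mirrors how the evaluator builds the
# 'additional_metrics' entry, using made-up helper and variable names.
from typing import Any, Dict, List

def build_metric_entry(primary_metric: str,
                       extra_metrics: List[str],
                       all_supported_metrics: bool) -> Dict[str, List[str]]:
    # The optimisation metric is always reported; the extra metrics are added
    # only when all_supported_metrics is requested, as in the diff above.
    names: List[str] = [primary_metric]
    if all_supported_metrics:
        names.extend(extra_metrics)
    return {'additional_metrics': names}

fit_dictionary: Dict[str, Any] = {'budget_type': 'epochs', 'epochs': 30}
fit_dictionary.update(build_metric_entry('balanced_accuracy',
                                         ['accuracy', 'f1'],
                                         all_supported_metrics=True))

# A trainer-like component can then look the names up when it logs
# per-epoch statistics.
print(fit_dictionary['additional_metrics'])  # ['balanced_accuracy', 'accuracy', 'f1']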

test/test_api/test_api.py

Lines changed: 8 additions, 3 deletions

@@ -503,7 +503,8 @@ def test_pipeline_fit(openml_id,
         run_time_limit_secs=50,
         budget_type='epochs',
         budget=30,
-        disable_file_output=disable_file_output
+        disable_file_output=disable_file_output,
+        eval_metric='balanced_accuracy'
     )
     assert isinstance(dataset, BaseDataset)
     assert isinstance(run_info, RunInfo)

@@ -513,6 +514,7 @@ def test_pipeline_fit(openml_id,
     assert 'SUCCESS' in str(run_value.status)

     if not disable_file_output:
+
         if resampling_strategy in CrossValTypes:
             pytest.skip("Bug, Can't predict with cross validation pipeline")
         assert isinstance(pipeline, BaseEstimator)

@@ -524,11 +526,14 @@
         assert isinstance(score, float)
         assert score > 0.8
     else:
-        assert isinstance(pipeline, BasePipeline)
         # To make sure we fitted the model, there should be a
-        # run summary object with accuracy
+        # run summary object
         run_summary = pipeline.named_steps['trainer'].run_summary
         assert run_summary is not None
+        # test to ensure balanced_accuracy is reported during training
+        assert 'balanced_accuracy' in run_summary.performance_tracker['train_metrics'][1].keys()
+
+        assert isinstance(pipeline, BasePipeline)
         X_test = dataset.test_tensors[0]
         preds = pipeline.predict(X_test)
         assert isinstance(preds, np.ndarray)
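The new assertion assumes the trainer's run summary exposes a performance_tracker dictionary whose 'train_metrics' entry maps epoch numbers to per-metric values. A minimal sketch of just that check follows, using a hand-built dictionary as a stand-in for a real fitted pipeline's run summary rather than output from an actual fit.

# Stand-in for run_summary.performance_tracker as the test reads it:
# 'train_metrics' maps epoch number -> {metric name: value}. Mock data,
# not produced by a real autoPyTorch run.
performance_tracker = {
    'train_metrics': {
        1: {'balanced_accuracy': 0.52, 'accuracy': 0.55},
        2: {'balanced_accuracy': 0.61, 'accuracy': 0.63},
    }
}

# Equivalent of the added assertion: the eval_metric passed to fit_pipeline
# ('balanced_accuracy') must be reported for the first training epoch.
assert 'balanced_accuracy' in performance_tracker['train_metrics'][1]

# The same structure lets a test (or user) follow the metric across epochs.
history = [m['balanced_accuracy'] for m in performance_tracker['train_metrics'].values()]
print(history)  # [0.52, 0.61]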
