Skip to content

Commit d763109

Browse files
committed
Additional metrics during train (#194)
* Added additional metrics to fit dictionary * Added in test also
1 parent a96b5b6 commit d763109

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

test/test_api/test_api.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@ def test_portfolio_selection(openml_id, backend, n_samples):
522522
resampling_strategy=HoldoutValTypes.holdout_validation,
523523
)
524524

525+
<<<<<<< HEAD
525526
with unittest.mock.patch.object(estimator, '_do_dummy_prediction', new=dummy_do_dummy_prediction):
526527
estimator.search(
527528
X_train=X_train, y_train=y_train,
@@ -533,6 +534,62 @@ def test_portfolio_selection(openml_id, backend, n_samples):
533534
portfolio_selection=os.path.join(os.path.dirname(__file__),
534535
"../../autoPyTorch/configs/greedy_portfolio.json")
535536
)
537+
=======
538+
dataset = estimator.get_dataset(X_train=X_train,
539+
y_train=y_train,
540+
X_test=X_test,
541+
y_test=y_test,
542+
resampling_strategy=resampling_strategy,
543+
resampling_strategy_args=resampling_strategy_args)
544+
545+
configuration = estimator.get_search_space(dataset).get_default_configuration()
546+
pipeline, run_info, run_value, dataset = estimator.fit_pipeline(dataset=dataset,
547+
configuration=configuration,
548+
run_time_limit_secs=50,
549+
budget_type='epochs',
550+
budget=30,
551+
disable_file_output=disable_file_output,
552+
eval_metric='balanced_accuracy'
553+
)
554+
assert isinstance(dataset, BaseDataset)
555+
assert isinstance(run_info, RunInfo)
556+
assert isinstance(run_info.config, Configuration)
557+
558+
assert isinstance(run_value, RunValue)
559+
assert 'SUCCESS' in str(run_value.status)
560+
561+
if not disable_file_output:
562+
563+
if resampling_strategy in CrossValTypes:
564+
pytest.skip("Bug, Can't predict with cross validation pipeline")
565+
assert isinstance(pipeline, BaseEstimator)
566+
X_test = dataset.test_tensors[0]
567+
preds = pipeline.predict(X_test)
568+
assert isinstance(preds, np.ndarray)
569+
570+
score = accuracy(dataset.test_tensors[1], preds)
571+
assert isinstance(score, float)
572+
assert score > 0.8
573+
else:
574+
# To make sure we fitted the model, there should be a
575+
# run summary object
576+
run_summary = pipeline.named_steps['trainer'].run_summary
577+
assert run_summary is not None
578+
# test to ensure balanced_accuracy is reported during training
579+
assert 'balanced_accuracy' in run_summary.performance_tracker['train_metrics'][1].keys()
580+
581+
assert isinstance(pipeline, BasePipeline)
582+
X_test = dataset.test_tensors[0]
583+
preds = pipeline.predict(X_test)
584+
assert isinstance(preds, np.ndarray)
585+
586+
score = accuracy(dataset.test_tensors[1], preds)
587+
assert isinstance(score, float)
588+
assert score > 0.8
589+
else:
590+
assert pipeline is None
591+
assert run_value.cost < 0.2
592+
>>>>>>> Additional metrics during train (#194)
536593

537594
successful_config_ids = [run_key.config_id for run_key, run_value in estimator.run_history.data.items(
538595
) if 'SUCCESS' in str(run_value.status)]

0 commit comments

Comments
 (0)