@@ -522,6 +522,7 @@ def test_portfolio_selection(openml_id, backend, n_samples):
522522 resampling_strategy = HoldoutValTypes .holdout_validation ,
523523 )
524524
525+ <<<<<<< HEAD
525526 with unittest .mock .patch .object (estimator , '_do_dummy_prediction' , new = dummy_do_dummy_prediction ):
526527 estimator .search (
527528 X_train = X_train , y_train = y_train ,
@@ -533,6 +534,62 @@ def test_portfolio_selection(openml_id, backend, n_samples):
533534 portfolio_selection = os .path .join (os .path .dirname (__file__ ),
534535 "../../autoPyTorch/configs/greedy_portfolio.json" )
535536 )
537+ =======
538+ dataset = estimator .get_dataset (X_train = X_train ,
539+ y_train = y_train ,
540+ X_test = X_test ,
541+ y_test = y_test ,
542+ resampling_strategy = resampling_strategy ,
543+ resampling_strategy_args = resampling_strategy_args )
544+
545+ configuration = estimator .get_search_space (dataset ).get_default_configuration ()
546+ pipeline , run_info , run_value , dataset = estimator .fit_pipeline (dataset = dataset ,
547+ configuration = configuration ,
548+ run_time_limit_secs = 50 ,
549+ budget_type = 'epochs' ,
550+ budget = 30 ,
551+ disable_file_output = disable_file_output ,
552+ eval_metric = 'balanced_accuracy'
553+ )
554+ assert isinstance (dataset , BaseDataset )
555+ assert isinstance (run_info , RunInfo )
556+ assert isinstance (run_info .config , Configuration )
557+
558+ assert isinstance (run_value , RunValue )
559+ assert 'SUCCESS' in str (run_value .status )
560+
561+ if not disable_file_output :
562+
563+ if resampling_strategy in CrossValTypes :
564+ pytest .skip ("Bug, Can't predict with cross validation pipeline" )
565+ assert isinstance (pipeline , BaseEstimator )
566+ X_test = dataset .test_tensors [0 ]
567+ preds = pipeline .predict (X_test )
568+ assert isinstance (preds , np .ndarray )
569+
570+ score = accuracy (dataset .test_tensors [1 ], preds )
571+ assert isinstance (score , float )
572+ assert score > 0.8
573+ else :
574+ # To make sure we fitted the model, there should be a
575+ # run summary object
576+ run_summary = pipeline .named_steps ['trainer' ].run_summary
577+ assert run_summary is not None
578+ # test to ensure balanced_accuracy is reported during training
579+ assert 'balanced_accuracy' in run_summary .performance_tracker ['train_metrics' ][1 ].keys ()
580+
581+ assert isinstance (pipeline , BasePipeline )
582+ X_test = dataset .test_tensors [0 ]
583+ preds = pipeline .predict (X_test )
584+ assert isinstance (preds , np .ndarray )
585+
586+ score = accuracy (dataset .test_tensors [1 ], preds )
587+ assert isinstance (score , float )
588+ assert score > 0.8
589+ else :
590+ assert pipeline is None
591+ assert run_value .cost < 0.2
592+ >>>>>>> Additional metrics during train (#194)
536593
537594 successful_config_ids = [run_key .config_id for run_key , run_value in estimator .run_history .data .items (
538595 ) if 'SUCCESS' in str (run_value .status )]
0 commit comments