@@ -932,6 +932,61 @@ def test_pipeline_fit(openml_id,
932
932
assert not os .path .exists (cv_model_path )
933
933
934
934
935
+ @pytest .mark .parametrize ('openml_id,budget' , [(40984 , 1 )])
936
+ def test_pipeline_fit_pass_pipeline_options (
937
+ openml_id ,
938
+ backend ,
939
+ budget ,
940
+ n_samples
941
+ ):
942
+ # Get the data and check that contents of data-manager make sense
943
+ X , y = sklearn .datasets .fetch_openml (
944
+ data_id = int (openml_id ),
945
+ return_X_y = True , as_frame = True
946
+ )
947
+ X_train , X_test , y_train , y_test = sklearn .model_selection .train_test_split (
948
+ X [:n_samples ], y [:n_samples ], random_state = 1 )
949
+
950
+ # Search for a good configuration
951
+ estimator = TabularClassificationTask (
952
+ backend = backend ,
953
+ ensemble_size = 0
954
+ )
955
+
956
+ dataset = estimator .get_dataset (X_train = X_train ,
957
+ y_train = y_train ,
958
+ X_test = X_test ,
959
+ y_test = y_test )
960
+
961
+ configuration = estimator .get_search_space (dataset ).get_default_configuration ()
962
+ pipeline , run_info , run_value , dataset = estimator .fit_pipeline (dataset = dataset ,
963
+ configuration = configuration ,
964
+ run_time_limit_secs = 50 ,
965
+ budget_type = 'epochs' ,
966
+ budget = budget ,
967
+ pipeline_options = {'early_stopping' : 100 }
968
+ )
969
+ assert isinstance (dataset , BaseDataset )
970
+ assert isinstance (run_info , RunInfo )
971
+ assert isinstance (run_info .config , Configuration )
972
+
973
+ assert isinstance (run_value , RunValue )
974
+ assert 'SUCCESS' in str (run_value .status )
975
+
976
+ # Make sure that the pipeline can be pickled
977
+ dump_file = os .path .join (tempfile .gettempdir (), 'automl.dump.pkl' )
978
+ with open (dump_file , 'wb' ) as f :
979
+ pickle .dump (pipeline , f )
980
+
981
+ num_run_dir = estimator ._backend .get_numrun_directory (
982
+ run_info .seed , run_value .additional_info ['num_run' ], budget = float (budget ))
983
+ model_path = os .path .join (num_run_dir , estimator ._backend .get_model_filename (
984
+ run_info .seed , run_value .additional_info ['num_run' ], budget = float (budget )))
985
+
986
+ # We expect the model path always
987
+ assert os .path .exists (model_path )
988
+
989
+
935
990
@pytest .mark .parametrize ('openml_id' , (40984 ,))
936
991
@pytest .mark .parametrize ('resampling_strategy,resampling_strategy_args' ,
937
992
((HoldoutValTypes .holdout_validation , {'val_share' : 0.8 }),
0 commit comments