
Commit 6d17012

[FIX] pipeline options in fit_pipeline (#466)

* fix update of pipeline config options in fit_pipeline
* fix flake and test
* suggestions from review
1 parent 4bcc583 commit 6d17012


2 files changed (+56, -1 lines)


autoPyTorch/api/base_task.py

Lines changed: 1 addition & 1 deletion
@@ -1632,7 +1632,7 @@ def fit_pipeline(
             names=[eval_metric] if eval_metric is not None else None,
             all_supported_metrics=False).pop()

-        pipeline_options = self.pipeline_options.copy().update(pipeline_options) if pipeline_options is not None \
+        pipeline_options = {**self.pipeline_options, **pipeline_options} if pipeline_options is not None \
             else self.pipeline_options.copy()

         assert pipeline_options is not None
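
The root cause is that dict.update() mutates its receiver and returns None, so the old expression self.pipeline_options.copy().update(pipeline_options) always produced None and tripped the assert pipeline_options is not None guard shown above whenever custom options were supplied. A minimal sketch with plain Python dicts (the option names and values here are illustrative, not autoPyTorch's actual defaults) shows the difference:

# Illustrative defaults only; the real values live in the task's pipeline_options.
defaults = {'budget_type': 'epochs', 'early_stopping': 20}
overrides = {'early_stopping': 100}

broken = defaults.copy().update(overrides)  # update() returns None, so the merged dict is lost
fixed = {**defaults, **overrides}           # new dict; keys from overrides take precedence

assert broken is None
assert fixed == {'budget_type': 'epochs', 'early_stopping': 100}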

test/test_api/test_api.py

Lines changed: 55 additions & 0 deletions
@@ -932,6 +932,61 @@ def test_pipeline_fit(openml_id,
     assert not os.path.exists(cv_model_path)


+@pytest.mark.parametrize('openml_id,budget', [(40984, 1)])
+def test_pipeline_fit_pass_pipeline_options(
+        openml_id,
+        backend,
+        budget,
+        n_samples
+):
+    # Get the data and check that contents of data-manager make sense
+    X, y = sklearn.datasets.fetch_openml(
+        data_id=int(openml_id),
+        return_X_y=True, as_frame=True
+    )
+    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
+        X[:n_samples], y[:n_samples], random_state=1)
+
+    # Search for a good configuration
+    estimator = TabularClassificationTask(
+        backend=backend,
+        ensemble_size=0
+    )
+
+    dataset = estimator.get_dataset(X_train=X_train,
+                                    y_train=y_train,
+                                    X_test=X_test,
+                                    y_test=y_test)
+
+    configuration = estimator.get_search_space(dataset).get_default_configuration()
+    pipeline, run_info, run_value, dataset = estimator.fit_pipeline(dataset=dataset,
+                                                                    configuration=configuration,
+                                                                    run_time_limit_secs=50,
+                                                                    budget_type='epochs',
+                                                                    budget=budget,
+                                                                    pipeline_options={'early_stopping': 100}
+                                                                    )
+    assert isinstance(dataset, BaseDataset)
+    assert isinstance(run_info, RunInfo)
+    assert isinstance(run_info.config, Configuration)
+
+    assert isinstance(run_value, RunValue)
+    assert 'SUCCESS' in str(run_value.status)
+
+    # Make sure that the pipeline can be pickled
+    dump_file = os.path.join(tempfile.gettempdir(), 'automl.dump.pkl')
+    with open(dump_file, 'wb') as f:
+        pickle.dump(pipeline, f)
+
+    num_run_dir = estimator._backend.get_numrun_directory(
+        run_info.seed, run_value.additional_info['num_run'], budget=float(budget))
+    model_path = os.path.join(num_run_dir, estimator._backend.get_model_filename(
+        run_info.seed, run_value.additional_info['num_run'], budget=float(budget)))
+
+    # We expect the model path always
+    assert os.path.exists(model_path)
+
+
 @pytest.mark.parametrize('openml_id', (40984,))
 @pytest.mark.parametrize('resampling_strategy,resampling_strategy_args',
                          ((HoldoutValTypes.holdout_validation, {'val_share': 0.8}),
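
A brief note on what this test guards: without the base_task.py change above, the fit_pipeline call would have left pipeline_options as None (the return value of the chained update()) and failed the assert pipeline_options is not None guard, so the test covers exactly the regression fixed by this commit. The early_stopping entry is a per-call option that, after the fix, is merged over the task-level pipeline option defaults.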

0 commit comments
