Skip to content

Commit b5d1c8b

Browse files
authored
change config_id to config_id+1 (#129)
1 parent a3d40ac commit b5d1c8b

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

test/test_api/test_api.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,33 +105,33 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
105105
# Search for an existing run key in disc. A individual model might have
106106
# a timeout and hence was not written to disc
107107
for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
108-
if i == 0:
109-
# Ignore dummy run
110-
continue
111108
if 'SUCCESS' not in str(value.status):
112109
continue
113110

114111
run_key_model_run_dir = estimator._backend.get_numrun_directory(
115-
estimator.seed, run_key.config_id, run_key.budget)
112+
estimator.seed, run_key.config_id + 1, run_key.budget)
116113
if os.path.exists(run_key_model_run_dir):
114+
# Runkey config id is different from the num_run
115+
# more specifically num_run = config_id + 1(dummy)
116+
successful_num_run = run_key.config_id + 1
117117
break
118118

119119
if resampling_strategy == HoldoutValTypes.holdout_validation:
120120
model_file = os.path.join(run_key_model_run_dir,
121-
f"{estimator.seed}.{run_key.config_id}.{run_key.budget}.model")
121+
f"{estimator.seed}.{successful_num_run}.{run_key.budget}.model")
122122
assert os.path.exists(model_file), model_file
123123
model = estimator._backend.load_model_by_seed_and_id_and_budget(
124-
estimator.seed, run_key.config_id, run_key.budget)
124+
estimator.seed, successful_num_run, run_key.budget)
125125
assert isinstance(model.named_steps['network'].get_network(), torch.nn.Module)
126126
elif resampling_strategy == CrossValTypes.k_fold_cross_validation:
127127
model_file = os.path.join(
128128
run_key_model_run_dir,
129-
f"{estimator.seed}.{run_key.config_id}.{run_key.budget}.cv_model"
129+
f"{estimator.seed}.{successful_num_run}.{run_key.budget}.cv_model"
130130
)
131131
assert os.path.exists(model_file), model_file
132132

133133
model = estimator._backend.load_cv_model_by_seed_and_id_and_budget(
134-
estimator.seed, run_key.config_id, run_key.budget)
134+
estimator.seed, successful_num_run, run_key.budget)
135135
assert isinstance(model, VotingClassifier)
136136
assert len(model.estimators_) == 3
137137
assert isinstance(model.estimators_[0].named_steps['network'].get_network(),
@@ -142,7 +142,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
142142
# Make sure that predictions on the test data are printed and make sense
143143
test_prediction = os.path.join(run_key_model_run_dir,
144144
estimator._backend.get_prediction_filename(
145-
'test', estimator.seed, run_key.config_id,
145+
'test', estimator.seed, successful_num_run,
146146
run_key.budget))
147147
assert os.path.exists(test_prediction), test_prediction
148148
assert np.shape(np.load(test_prediction, allow_pickle=True))[0] == np.shape(X_test)[0]
@@ -152,7 +152,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
152152
ensemble_prediction = os.path.join(run_key_model_run_dir,
153153
estimator._backend.get_prediction_filename(
154154
'ensemble',
155-
estimator.seed, run_key.config_id,
155+
estimator.seed, successful_num_run,
156156
run_key.budget))
157157
assert os.path.exists(ensemble_prediction), ensemble_prediction
158158
assert np.shape(np.load(ensemble_prediction, allow_pickle=True))[0] == np.shape(
@@ -213,10 +213,16 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
213213
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
214214
X, y, random_state=1)
215215

216+
include = None
217+
# for python less than 3.7, learned entity embedding
218+
# is not able to be stored on disk (only on CI)
219+
if sys.version_info < (3, 7):
220+
include = {'network_embedding': ['NoEmbedding']}
216221
# Search for a good configuration
217222
estimator = TabularRegressionTask(
218223
backend=backend,
219224
resampling_strategy=resampling_strategy,
225+
include_components=include
220226
)
221227

222228
estimator.search(
@@ -267,32 +273,32 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
267273
# Search for an existing run key in disc. A individual model might have
268274
# a timeout and hence was not written to disc
269275
for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
270-
if i == 0:
271-
# Ignore dummy run
272-
continue
273276
if 'SUCCESS' not in str(value.status):
274277
continue
275278

276279
run_key_model_run_dir = estimator._backend.get_numrun_directory(
277-
estimator.seed, run_key.config_id, run_key.budget)
280+
estimator.seed, run_key.config_id + 1, run_key.budget)
278281
if os.path.exists(run_key_model_run_dir):
282+
# Runkey config id is different from the num_run
283+
# more specifically num_run = config_id + 1(dummy)
284+
successful_num_run = run_key.config_id + 1
279285
break
280286

281287
if resampling_strategy == HoldoutValTypes.holdout_validation:
282288
model_file = os.path.join(run_key_model_run_dir,
283-
f"{estimator.seed}.{run_key.config_id}.{run_key.budget}.model")
289+
f"{estimator.seed}.{successful_num_run}.{run_key.budget}.model")
284290
assert os.path.exists(model_file), model_file
285291
model = estimator._backend.load_model_by_seed_and_id_and_budget(
286-
estimator.seed, run_key.config_id, run_key.budget)
292+
estimator.seed, successful_num_run, run_key.budget)
287293
assert isinstance(model.named_steps['network'].get_network(), torch.nn.Module)
288294
elif resampling_strategy == CrossValTypes.k_fold_cross_validation:
289295
model_file = os.path.join(
290296
run_key_model_run_dir,
291-
f"{estimator.seed}.{run_key.config_id}.{run_key.budget}.cv_model"
297+
f"{estimator.seed}.{successful_num_run}.{run_key.budget}.cv_model"
292298
)
293299
assert os.path.exists(model_file), model_file
294300
model = estimator._backend.load_cv_model_by_seed_and_id_and_budget(
295-
estimator.seed, run_key.config_id, run_key.budget)
301+
estimator.seed, successful_num_run, run_key.budget)
296302
assert isinstance(model, VotingRegressor)
297303
assert len(model.estimators_) == 3
298304
assert isinstance(model.estimators_[0].named_steps['network'].get_network(),
@@ -303,7 +309,7 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
303309
# Make sure that predictions on the test data are printed and make sense
304310
test_prediction = os.path.join(run_key_model_run_dir,
305311
estimator._backend.get_prediction_filename(
306-
'test', estimator.seed, run_key.config_id,
312+
'test', estimator.seed, successful_num_run,
307313
run_key.budget))
308314
assert os.path.exists(test_prediction), test_prediction
309315
assert np.shape(np.load(test_prediction, allow_pickle=True))[0] == np.shape(X_test)[0]
@@ -313,7 +319,7 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
313319
ensemble_prediction = os.path.join(run_key_model_run_dir,
314320
estimator._backend.get_prediction_filename(
315321
'ensemble',
316-
estimator.seed, run_key.config_id,
322+
estimator.seed, successful_num_run,
317323
run_key.budget))
318324
assert os.path.exists(ensemble_prediction), ensemble_prediction
319325
assert np.shape(np.load(ensemble_prediction, allow_pickle=True))[0] == np.shape(

0 commit comments

Comments
 (0)