Skip to content

Commit 452c629

Browse files
[Rebase]
1 parent c45087b commit 452c629

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

test/test_api/test_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
7272

7373
# Internal dataset has expected settings
7474
assert estimator.dataset.task_type == 'tabular_classification'
75-
expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 3
75+
expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 5
7676
assert estimator.resampling_strategy == resampling_strategy
7777
assert estimator.dataset.resampling_strategy == resampling_strategy
7878
assert len(estimator.dataset.splits) == expected_num_splits
@@ -140,7 +140,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
140140
model = estimator._backend.load_cv_model_by_seed_and_id_and_budget(
141141
estimator.seed, successful_num_run, run_key.budget)
142142
assert isinstance(model, VotingClassifier)
143-
assert len(model.estimators_) == 3
143+
assert len(model.estimators_) == 5
144144
assert isinstance(model.estimators_[0].named_steps['network'].get_network(),
145145
torch.nn.Module)
146146
else:
@@ -243,7 +243,7 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
243243

244244
# Internal dataset has expected settings
245245
assert estimator.dataset.task_type == 'tabular_regression'
246-
expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 3
246+
expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 5
247247
assert estimator.resampling_strategy == resampling_strategy
248248
assert estimator.dataset.resampling_strategy == resampling_strategy
249249
assert len(estimator.dataset.splits) == expected_num_splits
@@ -310,7 +310,7 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
310310
model = estimator._backend.load_cv_model_by_seed_and_id_and_budget(
311311
estimator.seed, successful_num_run, run_key.budget)
312312
assert isinstance(model, VotingRegressor)
313-
assert len(model.estimators_) == 3
313+
assert len(model.estimators_) == 5
314314
assert isinstance(model.estimators_[0].named_steps['network'].get_network(),
315315
torch.nn.Module)
316316
else:

0 commit comments

Comments
 (0)