Skip to content

Commit c017fac

Browse files
committed
add test for get dataset
1 parent 8055406 commit c017fac

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

test/test_datasets/test_tabular_dataset.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import pytest
44

5+
from autoPyTorch.data.tabular_validator import TabularInputValidator
6+
from autoPyTorch.datasets.base_dataset import TransformSubset
7+
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes
58
from autoPyTorch.datasets.tabular_dataset import TabularDataset
69
from autoPyTorch.utils.pipeline import get_dataset_requirements
710

@@ -46,3 +49,34 @@ def test_get_dataset_properties(backend, fit_dictionary_tabular):
4649
def test_not_supported():
4750
with pytest.raises(ValueError, match=r".*A feature validator is required to build.*"):
4851
TabularDataset(np.ones(10), np.ones(10))
52+
53+
54+
@pytest.mark.parametrize('resampling_strategy',
55+
(HoldoutValTypes.holdout_validation,
56+
CrossValTypes.k_fold_cross_validation,
57+
NoResamplingStrategyTypes.no_resampling
58+
))
59+
def test_get_dataset(resampling_strategy, n_samples):
60+
"""
61+
Checks the functionality of get_dataset function of the TabularDataset
62+
gives an error when trying to get training and validation subset
63+
"""
64+
X = np.zeros(shape=(n_samples, 4))
65+
Y = np.ones(n_samples)
66+
validator = TabularInputValidator(is_classification=True)
67+
validator.fit(X, Y)
68+
dataset = TabularDataset(
69+
resampling_strategy=resampling_strategy,
70+
X=X,
71+
Y=Y,
72+
validator=validator
73+
)
74+
transform_subset = dataset.get_dataset(split_id=0, train=True)
75+
assert isinstance(transform_subset, TransformSubset)
76+
77+
if isinstance(resampling_strategy, NoResamplingStrategyTypes):
78+
with pytest.raises(ValueError):
79+
dataset.get_dataset(split_id=0, train=False)
80+
else:
81+
transform_subset = dataset.get_dataset(split_id=0, train=False)
82+
assert isinstance(transform_subset, TransformSubset)

0 commit comments

Comments
 (0)