|
2 | 2 |
|
3 | 3 | import pytest
|
4 | 4 |
|
| 5 | +from autoPyTorch.data.tabular_validator import TabularInputValidator |
| 6 | +from autoPyTorch.datasets.base_dataset import TransformSubset |
| 7 | +from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes |
5 | 8 | from autoPyTorch.datasets.tabular_dataset import TabularDataset
|
6 | 9 | from autoPyTorch.utils.pipeline import get_dataset_requirements
|
7 | 10 |
|
@@ -46,3 +49,34 @@ def test_get_dataset_properties(backend, fit_dictionary_tabular):
|
46 | 49 | def test_not_supported():
|
47 | 50 | with pytest.raises(ValueError, match=r".*A feature validator is required to build.*"):
|
48 | 51 | TabularDataset(np.ones(10), np.ones(10))
|
| 52 | + |
| 53 | + |
| 54 | +@pytest.mark.parametrize('resampling_strategy', |
| 55 | + (HoldoutValTypes.holdout_validation, |
| 56 | + CrossValTypes.k_fold_cross_validation, |
| 57 | + NoResamplingStrategyTypes.no_resampling |
| 58 | + )) |
| 59 | +def test_get_dataset(resampling_strategy, n_samples): |
| 60 | + """ |
| 61 | + Checks the functionality of get_dataset function of the TabularDataset |
| 62 | + gives an error when trying to get training and validation subset |
| 63 | + """ |
| 64 | + X = np.zeros(shape=(n_samples, 4)) |
| 65 | + Y = np.ones(n_samples) |
| 66 | + validator = TabularInputValidator(is_classification=True) |
| 67 | + validator.fit(X, Y) |
| 68 | + dataset = TabularDataset( |
| 69 | + resampling_strategy=resampling_strategy, |
| 70 | + X=X, |
| 71 | + Y=Y, |
| 72 | + validator=validator |
| 73 | + ) |
| 74 | + transform_subset = dataset.get_dataset(split_id=0, train=True) |
| 75 | + assert isinstance(transform_subset, TransformSubset) |
| 76 | + |
| 77 | + if isinstance(resampling_strategy, NoResamplingStrategyTypes): |
| 78 | + with pytest.raises(ValueError): |
| 79 | + dataset.get_dataset(split_id=0, train=False) |
| 80 | + else: |
| 81 | + transform_subset = dataset.get_dataset(split_id=0, train=False) |
| 82 | + assert isinstance(transform_subset, TransformSubset) |
0 commit comments