Skip to content

Commit b836cf0

Browse files
committed
fix tests
1 parent e36ae4a commit b836cf0

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

autoPyTorch/data/tabular_validator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
dataset_compression: Optional[DatasetCompressionSpec] = None,
4949
seed: int = 42,
5050
) -> None:
51-
self._dataset_compression = dataset_compression
51+
self.dataset_compression = dataset_compression
5252
self._reduced_dtype: Optional[DatasetDTypeContainerType] = None
5353
self.is_classification = is_classification
5454
self.logger_port = logger_port
@@ -92,7 +92,7 @@ def _compress_dataset(
9292
"""
9393
is_dataframe = hasattr(X, 'iloc')
9494
is_reducible_type = isinstance(X, np.ndarray) or issparse(X) or is_dataframe
95-
if not is_reducible_type or self._dataset_compression is None:
95+
if not is_reducible_type or self.dataset_compression is None:
9696
return X, y
9797
elif self._reduced_dtype is not None:
9898
X = X.astype(self._reduced_dtype)
@@ -103,7 +103,7 @@ def _compress_dataset(
103103
y=y,
104104
is_classification=self.is_classification,
105105
random_state=self.seed,
106-
**self._dataset_compression # type: ignore [arg-type]
106+
**self.dataset_compression # type: ignore [arg-type]
107107
)
108108
self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype
109109
return X, y

test/test_data/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def test_unsupported_errors():
127127
['a', 'b', 'c', 'a', 'b', 'c'],
128128
['a', 'b', 'd', 'r', 'b', 'c']])
129129
with pytest.raises(ValueError, match=r'X.dtype = .*'):
130-
reduce_dataset_size_if_too_large(X, 0)
130+
reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0)
131131

132132
X = [[1, 2], [2, 3]]
133133
with pytest.raises(ValueError, match=r'Unrecognised data type of X, expected data type to be in .*'):
134-
reduce_dataset_size_if_too_large(X, 0)
134+
reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0)

test/test_data/test_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def test_validation_unsupported():
163163
),
164164
indirect=True
165165
)
166-
def test_featurevalidator_reduce_precision(input_data_featuretest):
166+
def test_featurevalidator_dataset_compression(input_data_featuretest):
167167
n_samples = input_data_featuretest.shape[0]
168168
input_data_targets = np.random.random_sample((n_samples))
169169
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(

0 commit comments

Comments (0)