Skip to content

Commit a38e83c

Browse files
committed
[refactor] Fix SparseMatrixType --> spmatrix
1 parent 9a7ce79 commit a38e83c

File tree

4 files changed

+19
-26
lines changed

4 files changed

+19
-26
lines changed

autoPyTorch/data/base_feature_validator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
import pandas as pd
77

8+
from scipy.sparse import spmatrix
9+
810
from sklearn.base import BaseEstimator
911

10-
from autoPyTorch.utils.common import SparseMatrixType
1112
from autoPyTorch.utils.logging_ import PicklableClientLogger
1213

1314

14-
SupportedFeatTypes = Union[List, pd.DataFrame, np.ndarray, SparseMatrixType]
15+
SupportedFeatTypes = Union[List, pd.DataFrame, np.ndarray, spmatrix]
1516

1617

1718
class BaseFeatureValidator(BaseEstimator):

autoPyTorch/data/base_target_validator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
import pandas as pd
77

8+
from scipy.sparse import spmatrix
9+
810
from sklearn.base import BaseEstimator
911

10-
from autoPyTorch.utils.common import SparseMatrixType
1112
from autoPyTorch.utils.logging_ import PicklableClientLogger
1213

1314

14-
SupportedTargetTypes = Union[List, pd.Series, pd.DataFrame, np.ndarray, SparseMatrixType]
15+
SupportedTargetTypes = Union[List, pd.Series, pd.DataFrame, np.ndarray, spmatrix]
1516

1617

1718
class BaseTargetValidator(BaseEstimator):

autoPyTorch/data/tabular_target_validator.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66
from pandas.api.types import is_numeric_dtype
77

8-
import scipy.sparse
8+
from scipy.sparse import issparse, spmatrix
99

1010
import sklearn.utils
1111
from sklearn import preprocessing
@@ -14,10 +14,10 @@
1414
from sklearn.utils.multiclass import type_of_target
1515

1616
from autoPyTorch.data.base_target_validator import BaseTargetValidator, SupportedTargetTypes
17-
from autoPyTorch.utils.common import SparseMatrixType
17+
from autoPyTorch.data.utils import ispandas
1818

1919

20-
ArrayType = Union[np.ndarray, SparseMatrixType]
20+
ArrayType = Union[np.ndarray, spmatrix]
2121

2222

2323
def _check_and_to_array(y: SupportedTargetTypes) -> ArrayType:
@@ -71,7 +71,7 @@ def _fit(
7171
return self
7272

7373
if y_test is not None:
74-
if hasattr(y_train, "iloc"):
74+
if ispandas(y_train):
7575
y_train = pd.concat([y_train, y_test], ignore_index=True, sort=False)
7676
elif isinstance(y_train, list):
7777
y_train = y_train + y_test
@@ -100,7 +100,7 @@ def _fit(
100100
if ndim > 1:
101101
self.encoder.fit(y_train)
102102
else:
103-
if hasattr(y_train, 'iloc'):
103+
if ispandas(y_train):
104104
y_train = cast(pd.DataFrame, y_train)
105105
self.encoder.fit(y_train.to_numpy().reshape(-1, 1))
106106
else:
@@ -131,7 +131,7 @@ def _transform_by_encoder(self, y: SupportedTargetTypes) -> np.ndarray:
131131
shape = np.shape(y)
132132
if len(shape) > 1:
133133
y = self.encoder.transform(y)
134-
elif hasattr(y, 'iloc'):
134+
elif ispandas(y):
135135
# The Ordinal encoder expects a 2 dimensional input.
136136
# The targets are 1 dimensional, so reshape to match the expected shape
137137
y = cast(pd.DataFrame, y)
@@ -192,7 +192,7 @@ def inverse_transform(self, y: SupportedTargetTypes) -> np.ndarray:
192192
y = self.encoder.inverse_transform(y)
193193
else:
194194
# The targets should be a flattened array, hence reshape with -1
195-
if hasattr(y, 'iloc'):
195+
if ispandas(y):
196196
y = cast(pd.DataFrame, y)
197197
y = self.encoder.inverse_transform(y.to_numpy().reshape(-1, 1)).reshape(-1)
198198
else:
@@ -216,7 +216,7 @@ def _check_data(self, y: SupportedTargetTypes) -> None:
216216

217217
if not isinstance(y, (np.ndarray, pd.DataFrame,
218218
List, pd.Series)) \
219-
and not scipy.sparse.issparse(y): # type: ignore[misc]
219+
and not issparse(y): # type: ignore[misc]
220220
raise ValueError("AutoPyTorch only supports Numpy arrays, Pandas DataFrames,"
221221
" pd.Series, sparse data and Python Lists as targets, yet, "
222222
"the provided input is of type {}".format(
@@ -225,8 +225,8 @@ def _check_data(self, y: SupportedTargetTypes) -> None:
225225

226226
# Sparse data muss be numerical
227227
# Type ignore on attribute because sparse targets have a dtype
228-
if scipy.sparse.issparse(y) and not np.issubdtype(y.dtype.type, # type: ignore[union-attr]
229-
np.number):
228+
if issparse(y) and not np.issubdtype(y.dtype.type, # type: ignore[union-attr]
229+
np.number):
230230
raise ValueError("When providing a sparse matrix as targets, the only supported "
231231
"values are numerical. Please consider using a dense"
232232
" instead."
@@ -245,10 +245,10 @@ def _check_data(self, y: SupportedTargetTypes) -> None:
245245

246246
# No Nan is supported
247247
has_nan_values = False
248-
if hasattr(y, 'iloc'):
248+
if ispandas(y):
249249
has_nan_values = cast(pd.DataFrame, y).isnull().values.any()
250-
if scipy.sparse.issparse(y):
251-
y = cast(scipy.sparse.spmatrix, y)
250+
if issparse(y):
251+
y = cast(spmatrix, y)
252252
has_nan_values = not np.array_equal(y.data, y.data)
253253
else:
254254
# List and array like values are considered here

autoPyTorch/utils/common.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,6 @@
2020
from torch.utils.data.dataloader import default_collate
2121

2222
HyperparameterValueType = Union[int, str, float]
23-
SparseMatrixType = Union[
24-
scipy.sparse.bsr_matrix,
25-
scipy.sparse.coo_matrix,
26-
scipy.sparse.csc_matrix,
27-
scipy.sparse.csr_matrix,
28-
scipy.sparse.dia_matrix,
29-
scipy.sparse.dok_matrix,
30-
scipy.sparse.lil_matrix,
31-
]
3223

3324

3425
class FitRequirement(NamedTuple):

0 commit comments

Comments
 (0)