
Commit 6d9f99f

ArlindKadra, ravinkohli and nabenabe0928 authored
Enhancement for the tabular validator. (#291)
* Initial try at an enhancement for the tabular validator
* Adding a few type annotations
* Fixing bugs in implementation
* Adding wrongly deleted code part during rebase
* Fix bug in _get_args
* Fix bug in _get_args
* Addressing Shuhei's comments
* Address Shuhei's comments
* Refactoring code
* Refactoring code
* Typos fix and additional comments
* Replace nan in categoricals with simple imputer
* Remove unused function
* add comment
* Update autoPyTorch/data/tabular_feature_validator.py (Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>)
* Update autoPyTorch/data/tabular_feature_validator.py (Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>)
* Adding unit test for only nan columns in the tabular feature categorical evaluator
* fix bug in remove all nan columns
* Bug fix for making tests run by arlind
* fix flake errors in feature validator
* made typing code uniform
* Apply suggestions from code review (Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>)
* address comments from shuhei
* address comments from shuhei (2)

Co-authored-by: Ravin Kohli <kohliravin7@gmail.com>
Co-authored-by: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com>
Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent f79a4fc commit 6d9f99f
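One of the changes listed above replaces NaN values in categorical columns with a simple imputer. As a minimal, illustrative sketch of that idea, using scikit-learn's SimpleImputer on a toy frame (an assumption about the approach, not the exact code added to tabular_feature_validator.py):

import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Toy frame with a missing value in a categorical column.
X = pd.DataFrame({"color": ["red", np.nan, "blue", "red"]})

# Fill missing categories with the most frequent value of that column.
imputer = SimpleImputer(strategy="most_frequent")
X[["color"]] = imputer.fit_transform(X[["color"]])
print(X["color"].tolist())  # ['red', 'red', 'blue', 'red']

A dedicated "missing" category could be used instead by passing strategy="constant" with an explicit fill_value.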

File tree

4 files changed: +290 additions, -200 deletions


autoPyTorch/data/base_feature_validator.py

Lines changed: 51 additions & 22 deletions
@@ -1,5 +1,5 @@
 import logging
-import typing
+from typing import List, Optional, Set, Tuple, Union
 
 import numpy as np
 
@@ -12,8 +12,8 @@
 from autoPyTorch.utils.logging_ import PicklableClientLogger
 
 
-SUPPORTED_FEAT_TYPES = typing.Union[
-    typing.List,
+SUPPORTED_FEAT_TYPES = Union[
+    List,
     pd.DataFrame,
     np.ndarray,
     scipy.sparse.bsr_matrix,
@@ -35,43 +35,44 @@ class BaseFeatureValidator(BaseEstimator):
             List of the column types found by this estimator during fit.
         data_type (str):
             Class name of the data type provided during fit.
-        encoder (typing.Optional[BaseEstimator])
+        encoder (Optional[BaseEstimator])
             Host a encoder object if the data requires transformation (for example,
             if provided a categorical column in a pandas DataFrame)
-        enc_columns (typing.List[str])
+        enc_columns (List[str])
             List of columns that were encoded.
     """
     def __init__(self,
-                 logger: typing.Optional[typing.Union[PicklableClientLogger, logging.Logger
-                                                      ]] = None,
+                 logger: Optional[Union[PicklableClientLogger, logging.Logger
+                                        ]
+                                  ] = None,
                  ) -> None:
         # Register types to detect unsupported data format changes
-        self.feat_type = None  # type: typing.Optional[typing.List[str]]
-        self.data_type = None  # type: typing.Optional[type]
-        self.dtypes = []  # type: typing.List[str]
-        self.column_order = []  # type: typing.List[str]
+        self.feat_type: Optional[List[str]] = None
+        self.data_type: Optional[type] = None
+        self.dtypes: List[str] = []
+        self.column_order: List[str] = []
 
-        self.encoder = None  # type: typing.Optional[BaseEstimator]
-        self.enc_columns = []  # type: typing.List[str]
+        self.encoder: Optional[BaseEstimator] = None
+        self.enc_columns: List[str] = []
 
-        self.logger: typing.Union[
+        self.logger: Union[
             PicklableClientLogger, logging.Logger
         ] = logger if logger is not None else logging.getLogger(__name__)
 
         # Required for dataset properties
-        self.num_features = None  # type: typing.Optional[int]
-        self.categories = []  # type: typing.List[typing.List[int]]
-        self.categorical_columns: typing.List[int] = []
-        self.numerical_columns: typing.List[int] = []
-        # column identifiers may be integers or strings
-        self.null_columns: typing.Set[str] = set()
+        self.num_features: Optional[int] = None
+        self.categories: List[List[int]] = []
+        self.categorical_columns: List[int] = []
+        self.numerical_columns: List[int] = []
+
+        self.all_nan_columns: Optional[Set[Union[int, str]]] = None
 
         self._is_fitted = False
 
     def fit(
         self,
         X_train: SUPPORTED_FEAT_TYPES,
-        X_test: typing.Optional[SUPPORTED_FEAT_TYPES] = None,
+        X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
     ) -> BaseEstimator:
         """
         Validates and fit a categorical encoder (if needed) to the features.
@@ -82,7 +83,7 @@ def fit(
             X_train (SUPPORTED_FEAT_TYPES):
                 A set of features that are going to be validated (type and dimensionality
                 checks) and a encoder fitted in the case the data needs encoding
-            X_test (typing.Optional[SUPPORTED_FEAT_TYPES]):
+            X_test (Optional[SUPPORTED_FEAT_TYPES]):
                 A hold out set of data used for checking
         """
 
@@ -122,6 +123,7 @@ def _fit(
             self:
                 The fitted base estimator
         """
+
         raise NotImplementedError()
 
     def _check_data(
@@ -136,6 +138,7 @@ def _check_data(
                 A set of features that are going to be validated (type and dimensionality
                 checks) and a encoder fitted in the case the data needs encoding
         """
+
         raise NotImplementedError()
 
     def transform(
@@ -152,4 +155,30 @@ def transform(
             np.ndarray:
                 The transformed array
         """
+
+        raise NotImplementedError()
+
+    def list_to_dataframe(
+        self,
+        X_train: SUPPORTED_FEAT_TYPES,
+        X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
+    ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
+        """
+        Converts a list to a pandas DataFrame. In this process, column types are inferred.
+
+        If test data is provided, we proactively match it to train data
+
+        Arguments:
+            X_train (SUPPORTED_FEAT_TYPES):
+                A set of features that are going to be validated (type and dimensionality
+                checks) and a encoder fitted in the case the data needs encoding
+            X_test (Optional[SUPPORTED_FEAT_TYPES]):
+                A hold out set of data used for checking
+        Returns:
+            pd.DataFrame:
+                transformed train data from list to pandas DataFrame
+            pd.DataFrame:
+                transformed test data from list to pandas DataFrame
+        """
+
         raise NotImplementedError()
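The new list_to_dataframe hook above only raises NotImplementedError in the base class; the concrete behaviour is supplied by the tabular feature validator. A rough, hypothetical sketch of what such an implementation could look like (written as a standalone function for illustration, not the code from this commit):

from typing import Any, List, Optional, Tuple

import pandas as pd


def list_to_dataframe(
    X_train: List[Any],
    X_test: Optional[List[Any]] = None,
) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
    # Let pandas infer the column dtypes from the training data.
    X_train_df = pd.DataFrame(X_train).infer_objects()
    X_test_df = None
    if X_test is not None:
        # Reuse the training dtypes so train and test columns stay consistent.
        X_test_df = pd.DataFrame(X_test).astype(X_train_df.dtypes.to_dict())
    return X_train_df, X_test_df

Imposing the training dtypes on the hold-out frame is one way to read the docstring's "we proactively match it to train data": type inference happens once, on the training split, and the test split simply follows it.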

autoPyTorch/data/base_target_validator.py

Lines changed: 20 additions & 20 deletions
@@ -1,5 +1,5 @@
 import logging
-import typing
+from typing import List, Optional, Union, cast
 
 import numpy as np
 
@@ -12,8 +12,8 @@
 from autoPyTorch.utils.logging_ import PicklableClientLogger
 
 
-SUPPORTED_TARGET_TYPES = typing.Union[
-    typing.List,
+SUPPORTED_TARGET_TYPES = Union[
+    List,
     pd.Series,
     pd.DataFrame,
     np.ndarray,
@@ -35,39 +35,39 @@ class BaseTargetValidator(BaseEstimator):
         is_classification (bool):
             A bool that indicates if the validator should operate in classification mode.
             During classification, the targets are encoded.
-        encoder (typing.Optional[BaseEstimator]):
+        encoder (Optional[BaseEstimator]):
             Host a encoder object if the data requires transformation (for example,
             if provided a categorical column in a pandas DataFrame)
-        enc_columns (typing.List[str])
+        enc_columns (List[str])
             List of columns that where encoded
     """
     def __init__(self,
                  is_classification: bool = False,
-                 logger: typing.Optional[typing.Union[PicklableClientLogger, logging.Logger
+                 logger: Optional[Union[PicklableClientLogger, logging.Logger
                                                       ]] = None,
                  ) -> None:
         self.is_classification = is_classification
 
-        self.data_type = None  # type: typing.Optional[type]
+        self.data_type: Optional[type] = None
 
-        self.encoder = None  # type: typing.Optional[BaseEstimator]
+        self.encoder: Optional[BaseEstimator] = None
 
-        self.out_dimensionality = None  # type: typing.Optional[int]
-        self.type_of_target = None  # type: typing.Optional[str]
+        self.out_dimensionality: Optional[int] = None
+        self.type_of_target: Optional[str] = None
 
-        self.logger: typing.Union[
+        self.logger: Union[
            PicklableClientLogger, logging.Logger
        ] = logger if logger is not None else logging.getLogger(__name__)
 
        # Store the dtype for remapping to correct type
-        self.dtype = None  # type: typing.Optional[type]
+        self.dtype: Optional[type] = None
 
        self._is_fitted = False
 
    def fit(
        self,
        y_train: SUPPORTED_TARGET_TYPES,
-        y_test: typing.Optional[SUPPORTED_TARGET_TYPES] = None,
+        y_test: Optional[SUPPORTED_TARGET_TYPES] = None,
    ) -> BaseEstimator:
        """
        Validates and fit a categorical encoder (if needed) to the targets
@@ -76,7 +76,7 @@ def fit(
        Arguments:
            y_train (SUPPORTED_TARGET_TYPES)
                A set of targets set aside for training
-            y_test (typing.Union[SUPPORTED_TARGET_TYPES])
+            y_test (Union[SUPPORTED_TARGET_TYPES])
                A hold out set of data used of the targets. It is also used to fit the
                categories of the encoder.
        """
@@ -95,8 +95,8 @@ def fit(
                        np.shape(y_test)
                    ))
            if isinstance(y_train, pd.DataFrame):
-                y_train = typing.cast(pd.DataFrame, y_train)
-                y_test = typing.cast(pd.DataFrame, y_test)
+                y_train = cast(pd.DataFrame, y_train)
+                y_test = cast(pd.DataFrame, y_test)
                if y_train.columns.tolist() != y_test.columns.tolist():
                    raise ValueError(
                        "Train and test targets must both have the same columns, yet "
@@ -127,21 +127,21 @@ def fit(
    def _fit(
        self,
        y_train: SUPPORTED_TARGET_TYPES,
-        y_test: typing.Optional[SUPPORTED_TARGET_TYPES] = None,
+        y_test: Optional[SUPPORTED_TARGET_TYPES] = None,
    ) -> BaseEstimator:
        """
        Arguments:
            y_train (SUPPORTED_TARGET_TYPES)
                The labels of the current task. They are going to be encoded in case
                of classification
-            y_test (typing.Optional[SUPPORTED_TARGET_TYPES])
+            y_test (Optional[SUPPORTED_TARGET_TYPES])
                A holdout set of labels
        """
        raise NotImplementedError()
 
    def transform(
        self,
-        y: typing.Union[SUPPORTED_TARGET_TYPES],
+        y: Union[SUPPORTED_TARGET_TYPES],
    ) -> np.ndarray:
        """
        Arguments:
@@ -162,7 +162,7 @@ def inverse_transform(
        Revert any encoding transformation done on a target array
 
        Arguments:
-            y (typing.Union[np.ndarray, pd.DataFrame, pd.Series]):
+            y (Union[np.ndarray, pd.DataFrame, pd.Series]):
                Target array to be transformed back to original form before encoding
        Returns:
            np.ndarray:
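For context on the new all_nan_columns attribute added to base_feature_validator.py above: several bullets in the commit message deal with removing columns that contain only NaN values. A hedged sketch of how such columns could be detected and dropped during fit (assumed logic for illustration, not necessarily what tabular_feature_validator.py does):

import numpy as np
import pandas as pd

X = pd.DataFrame({"a": [1.0, 2.0], "b": [np.nan, np.nan], "c": ["x", None]})

# Column labels can be ints or strings, hence Set[Union[int, str]] in the base class.
all_nan_columns = {col for col in X.columns if X[col].isna().all()}

# Drop the all-NaN columns so downstream encoders and imputers never see them.
X = X.drop(columns=list(all_nan_columns))
print(sorted(all_nan_columns))  # ['b']

Keeping the dropped labels around, as the all_nan_columns attribute suggests, would presumably let transform() remove the same columns from test data later on.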