
Commit a5e8ce6

ArlindKadra authored and ravinkohli committed

Fixing issues with imbalanced datasets (#197)

* adding missing method from base_feature_validator
* First try at a fix, removing redundant code
* Fix bug
* Updating unit test typo, fixing bug where the data type was not checked because X was a numpy array at the time of checking
* Fixing flake8 failures
* Bug fix, implementation update for imbalanced datasets and unit tests to check the implementation
* flake8 fix
* Bug fix
* Making the conversion to dataframe in the unit tests consistent with what happens at the validator, so the types do not change
* flake8 fix
* Addressing Ravin's comments

1 parent 147cf20 · commit a5e8ce6

File tree: 4 files changed, +166 −38 lines


autoPyTorch/data/base_feature_validator.py

Lines changed: 14 additions & 0 deletions

```diff
@@ -111,6 +111,20 @@ def _fit(
         """
         raise NotImplementedError()
 
+    def _check_data(
+        self,
+        X: SUPPORTED_FEAT_TYPES,
+    ) -> None:
+        """
+        Feature dimensionality and data type checks
+
+        Arguments:
+            X (SUPPORTED_FEAT_TYPES):
+                A set of features that are going to be validated (type and dimensionality
+                checks) and an encoder fitted in the case the data needs encoding
+        """
+        raise NotImplementedError()
+
     def transform(
         self,
         X: SupportedFeatTypes,
```
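Note: `_check_data` is introduced here as an abstract hook on the base validator; concrete subclasses such as `TabularFeatureValidator` provide the real checks. As a rough sketch of what an override could look like (the class and checks below are illustrative only, not the repository's implementation):

```python
import numpy as np
import pandas as pd


# Illustrative subclass; the actual checks live in TabularFeatureValidator.
class ToyFeatureValidator:
    def _check_data(self, X) -> None:
        # Dimensionality check: only tabular (at most 2D) inputs make sense
        if isinstance(X, np.ndarray) and X.ndim > 2:
            raise ValueError(f"Expected at most 2 dimensions, got {X.ndim}")
        # Type check: reject containers the validator cannot handle
        if not isinstance(X, (np.ndarray, pd.DataFrame, list)):
            raise TypeError(f"Unsupported input type: {type(X)}")
```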

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 38 additions & 32 deletions
```diff
@@ -156,9 +156,13 @@ def _fit(
         # with nan values.
         # Columns that are completely made of NaN values are provided to the pipeline
         # so that later stages decide how to handle them
+
+        # Clear whatever null column markers we had previously
+        self.null_columns.clear()
         if np.any(pd.isnull(X)):
             for column in X.columns:
                 if X[column].isna().all():
+                    self.null_columns.add(column)
                     X[column] = pd.to_numeric(X[column])
         # Also note this change in self.dtypes
         if len(self.dtypes) != 0:
```
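Note: the reason all-NaN columns can be handed off safely is that `pd.to_numeric` collapses an all-NaN object column to a plain float column. A quick standalone check of that pandas behavior:

```python
import numpy as np
import pandas as pd

s = pd.Series([np.nan, np.nan, np.nan], dtype=object)
print(s.dtype)                 # object
print(pd.to_numeric(s).dtype)  # float64: the all-NaN column is now numeric
```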
```diff
@@ -167,9 +171,8 @@ def _fit(
         if not X.select_dtypes(include='object').empty:
             X = self.infer_objects(X)
 
-        self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)
-
-        assert self.feat_type is not None
+        self._check_data(X)
+        self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
 
         if len(self.transformed_columns) > 0:
```

```diff
@@ -238,30 +241,38 @@ def transform(
         if isinstance(X, np.ndarray):
             X = self.numpy_array_to_pandas(X)
 
-        if ispandas(X) and not issparse(X):
-            if np.any(pd.isnull(X)):
-                for column in X.columns:
-                    if X[column].isna().all():
-                        X[column] = pd.to_numeric(X[column])
+        if hasattr(X, "iloc") and not issparse(X):
+            X = cast(pd.DataFrame, X)
+            # If we had null columns in our fit call and we made them numeric, then:
+            # - If the columns are null even in transform, apply the same procedure.
+            # - Otherwise, substitute the values with np.NaN and then make the columns numeric.
+            # If the column is null here, but it was not in fit, it does not matter.
+            for column in self.null_columns:
+                # The column is not null, make it null since it was null in fit.
+                if not X[column].isna().all():
+                    X[column] = np.NaN
+                X[column] = pd.to_numeric(X[column])
+
+            # for the test set, if we have columns with only null values
+            # they will probably have a numeric type. If these columns were not
+            # with only null values in the train set, they should be converted
+            # to the type that they had during fitting.
+            for column in X.columns:
+                if X[column].isna().all():
+                    X[column] = X[column].astype(self.dtypes[list(X.columns).index(column)])
 
         # Also remove the object dtype for new data
         if not X.select_dtypes(include='object').empty:
             X = self.infer_objects(X)
 
         # Check the data here so we catch problems on new test data
         self._check_data(X)
+        # We also need to fillna on the transformation
+        # in case test data is provided
+        X = self.impute_nan_in_categories(X)
 
-        # Pandas related transformations
-        if ispandas(X) and self.column_transformer is not None:
-            if np.any(pd.isnull(X)):
-                # After above check it means that if there is a NaN
-                # the whole column must be NaN
-                # Make sure it is numerical and let the pipeline handle it
-                for column in X.columns:
-                    if X[column].isna().all():
-                        X[column] = pd.to_numeric(X[column])
-
-            X = self.column_transformer.transform(X)
+        if self.encoder is not None:
+            X = self.encoder.transform(X)
 
         # Sparse related transformations
         # Not all sparse format support index sorting
```
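Note: the key behavioral change is that columns recorded as all-NaN during `fit` are forced back to all-NaN numeric columns at `transform` time, so train and test frames keep matching dtypes. A self-contained sketch of that idea (a toy class, not the validator itself):

```python
import numpy as np
import pandas as pd


class NullColumnTracker:
    """Toy illustration of the null-column bookkeeping added in this commit."""

    def __init__(self) -> None:
        self.null_columns: set = set()

    def fit(self, X: pd.DataFrame) -> None:
        self.null_columns.clear()
        for column in X.columns:
            if X[column].isna().all():
                # Remember the all-NaN column and make it numeric
                self.null_columns.add(column)
                X[column] = pd.to_numeric(X[column])

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        for column in self.null_columns:
            # The column was all-NaN in fit, so it must be all-NaN here too
            if not X[column].isna().all():
                X[column] = np.nan
            X[column] = pd.to_numeric(X[column])
        return X


tracker = NullColumnTracker()
tracker.fit(pd.DataFrame({'A': [np.nan, np.nan], 'B': [1, 2]}))
X_test = tracker.transform(pd.DataFrame({'A': [3, 4], 'B': [5, 6]}))
print(X_test['A'].isna().all())  # True: 'A' was null in fit, so it is nulled here
```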
```diff
@@ -488,7 +499,7 @@ def numpy_array_to_pandas(
         Returns:
             pd.DataFrame
         """
-        return pd.DataFrame(X).infer_objects().convert_dtypes()
+        return pd.DataFrame(X).convert_dtypes()
 
     def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
         """
```
```diff
@@ -506,18 +517,13 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
         if hasattr(self, 'object_dtype_mapping'):
             # Mypy does not process the has attr. This dict is defined below
             for key, dtype in self.object_dtype_mapping.items():  # type: ignore[has-type]
-                if 'int' in dtype.name:
-                    # In the case train data was interpreted as int
-                    # and test data was interpreted as float, because of 0.0
-                    # for example, honor training data
-                    X[key] = X[key].applymap(np.int64)
-                else:
-                    try:
-                        X[key] = X[key].astype(dtype.name)
-                    except Exception as e:
-                        # Try inference if possible
-                        self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
-                        pass
+                # honor the training data types
+                try:
+                    X[key] = X[key].astype(dtype.name)
+                except Exception as e:
+                    # Try inference if possible
+                    self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
+                    pass
         else:
             X = X.infer_objects()
             for column in X.columns:
```
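Note: the simplified branch honors the recorded training dtype for every column through a single `astype` call, rather than special-casing integer columns with `applymap(np.int64)`. A minimal sketch of the idea, with a hypothetical `object_dtype_mapping` built inline:

```python
import pandas as pd

# Hypothetical mapping recorded at fit time: column name -> training dtype
object_dtype_mapping = {'age': pd.Series([1, 2]).dtype}  # int64

X = pd.DataFrame({'age': [1.0, 2.0]})  # test data arrived as float
for key, dtype in object_dtype_mapping.items():
    try:
        # honor the training data types
        X[key] = X[key].astype(dtype.name)
    except Exception as e:
        print(f"Tried to cast column {key} to {dtype} caused {e}")
print(X['age'].dtype)  # int64
```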

test/test_data/test_feature_validator.py

Lines changed: 114 additions & 5 deletions
```diff
@@ -1,4 +1,4 @@
-import copy
+import copy
 import functools
 
 import numpy as np
```
```diff
@@ -139,9 +139,9 @@ def test_featurevalidator_fitontypeA_transformtypeB(input_data_featuretest):
     if isinstance(input_data_featuretest, pd.DataFrame):
         pytest.skip("Column order change in pandas is not supported")
     elif isinstance(input_data_featuretest, np.ndarray):
-        complementary_type = pd.DataFrame(input_data_featuretest)
+        complementary_type = validator.numpy_array_to_pandas(input_data_featuretest)
     elif isinstance(input_data_featuretest, list):
-        complementary_type = pd.DataFrame(input_data_featuretest)
+        complementary_type, _ = validator.list_to_dataframe(input_data_featuretest)
     elif sparse.issparse(input_data_featuretest):
         complementary_type = sparse.csr_matrix(input_data_featuretest.todense())
     else:
```
```diff
@@ -331,8 +331,11 @@ def test_unknown_encode_value():
 )
 @pytest.mark.parametrize('train_data_type', ('numpy', 'pandas', 'list'))
 @pytest.mark.parametrize('test_data_type', ('numpy', 'pandas', 'list'))
-def test_featurevalidator_new_data_after_fit(openml_id,
-                                             train_data_type, test_data_type):
+def test_feature_validator_new_data_after_fit(
+    openml_id,
+    train_data_type,
+    test_data_type,
+):
 
     # List is currently not supported as infer_objects
     # cast list objects to type objects
```
```diff
@@ -406,3 +409,109 @@ def test_comparator():
         key=functools.cmp_to_key(validator._comparator)
     )
     assert ans == feat_type
+
+
+# Actual checks for the features
+@pytest.mark.parametrize(
+    'input_data_featuretest',
+    (
+        'numpy_numericalonly_nonan',
+        'numpy_numericalonly_nan',
+        'numpy_mixed_nan',
+        'pandas_numericalonly_nan',
+        'sparse_bsr_nonan',
+        'sparse_bsr_nan',
+        'sparse_coo_nonan',
+        'sparse_coo_nan',
+        'sparse_csc_nonan',
+        'sparse_csc_nan',
+        'sparse_csr_nonan',
+        'sparse_csr_nan',
+        'sparse_dia_nonan',
+        'sparse_dia_nan',
+        'sparse_dok_nonan',
+        'sparse_dok_nan',
+        'openml_40981',  # Australian
+    ),
+    indirect=True
+)
+def test_featurevalidator_reduce_precision(input_data_featuretest):
+    X_train, X_test = sklearn.model_selection.train_test_split(
+        input_data_featuretest, test_size=0.1, random_state=1)
+    validator = TabularFeatureValidator(dataset_compression={'memory_allocation': 0, 'methods': ['precision']})
+    validator.fit(X_train=X_train)
+    transformed_X_train = validator.transform(X_train.copy())
+
+    assert validator._reduced_dtype is not None
+    assert megabytes(transformed_X_train) < megabytes(X_train)
+
+    transformed_X_test = validator.transform(X_test.copy())
+    assert megabytes(transformed_X_test) < megabytes(X_test)
+    if hasattr(transformed_X_train, 'iloc'):
+        assert all(transformed_X_train.dtypes == transformed_X_test.dtypes)
+        assert all(transformed_X_train.dtypes == validator._precision)
+    else:
+        assert transformed_X_train.dtype == transformed_X_test.dtype
+        assert transformed_X_test.dtype == validator._reduced_dtype
+
+
+def test_feature_validator_imbalanced_data():
+
+    # Null columns in the train split but not necessarily in the test split
+    train_features = {
+        'A': [np.NaN, np.NaN, np.NaN],
+        'B': [1, 2, 3],
+        'C': [np.NaN, np.NaN, np.NaN],
+        'D': [np.NaN, np.NaN, np.NaN],
+    }
+    test_features = {
+        'A': [3, 4, 5],
+        'B': [6, 5, 7],
+        'C': [np.NaN, np.NaN, np.NaN],
+        'D': ['Blue', np.NaN, np.NaN],
+    }
+
+    X_train = pd.DataFrame.from_dict(train_features)
+    X_test = pd.DataFrame.from_dict(test_features)
+    validator = TabularFeatureValidator()
+    validator.fit(X_train)
+
+    train_feature_types = copy.deepcopy(validator.feat_type)
+    assert train_feature_types == ['numerical', 'numerical', 'numerical', 'numerical']
+    # validator will throw an error if the column types are not the same
+    transformed_X_test = validator.transform(X_test)
+    transformed_X_test = pd.DataFrame(transformed_X_test)
+    null_columns = []
+    for column in transformed_X_test.columns:
+        if transformed_X_test[column].isna().all():
+            null_columns.append(column)
+    assert null_columns == [0, 2, 3]
+
+    # Columns with not all null values in the train split and
+    # completely null on the test split.
+    train_features = {
+        'A': [np.NaN, np.NaN, 4],
+        'B': [1, 2, 3],
+        'C': ['Blue', np.NaN, np.NaN],
+    }
+    test_features = {
+        'A': [np.NaN, np.NaN, np.NaN],
+        'B': [6, 5, 7],
+        'C': [np.NaN, np.NaN, np.NaN],
+    }
+
+    X_train = pd.DataFrame.from_dict(train_features)
+    X_test = pd.DataFrame.from_dict(test_features)
+    validator = TabularFeatureValidator()
+    validator.fit(X_train)
+    train_feature_types = copy.deepcopy(validator.feat_type)
+    assert train_feature_types == ['categorical', 'numerical', 'numerical']
+
+    transformed_X_test = validator.transform(X_test)
+    transformed_X_test = pd.DataFrame(transformed_X_test)
+    null_columns = []
+    for column in transformed_X_test.columns:
+        if transformed_X_test[column].isna().all():
+            null_columns.append(column)
+
+    assert null_columns == [1]
```
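Note: assuming a standard development checkout with pytest installed, the new test can be run in isolation with:

```
python -m pytest test/test_data/test_feature_validator.py::test_feature_validator_imbalanced_data
```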

test/test_data/test_validation.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -32,7 +32,6 @@ def test_data_validation_for_classification(openmlid, as_frame):
         x, y, test_size=0.33, random_state=0)
 
     validator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
-
    X_train_t, y_train_t = validator.transform(X_train, y_train)
     assert np.shape(X_train) == np.shape(X_train_t)
```
