Skip to content

Commit 6cbe591

Browse files
committed
[FIX] Enable preprocessing in reg_cocktails (#369)
* enable preprocessing and remove is_small_preprocess * address comments from shuhei and fix precommit checks * fix tests * fix precommit checks * add suggestions from shuhei for astype use * address speed issue when using object_dtype_mapping * make code more readable * improve documentation for base network embedding
1 parent abd1588 commit 6cbe591

34 files changed

+172
-793
lines changed

autoPyTorch/api/tabular_classification.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
1818
from autoPyTorch.datasets.resampling_strategy import (
1919
HoldoutValTypes,
20-
CrossValTypes,
2120
ResamplingStrategies,
2221
)
2322
from autoPyTorch.datasets.tabular_dataset import TabularDataset

autoPyTorch/api/tabular_regression.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
1818
from autoPyTorch.datasets.resampling_strategy import (
1919
HoldoutValTypes,
20-
CrossValTypes,
2120
ResamplingStrategies,
2221
)
2322
from autoPyTorch.datasets.tabular_dataset import TabularDataset

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 80 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sklearn.exceptions import NotFittedError
1616
from sklearn.impute import SimpleImputer
1717
from sklearn.pipeline import make_pipeline
18-
from sklearn.preprocessing import OneHotEncoder, StandardScaler
18+
from sklearn.preprocessing import OrdinalEncoder
1919

2020
from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SupportedFeatTypes
2121
from autoPyTorch.data.utils import (
@@ -28,7 +28,6 @@
2828

2929
def _create_column_transformer(
3030
preprocessors: Dict[str, List[BaseEstimator]],
31-
numerical_columns: List[str],
3231
categorical_columns: List[str],
3332
) -> ColumnTransformer:
3433
"""
@@ -39,49 +38,36 @@ def _create_column_transformer(
3938
Args:
4039
preprocessors (Dict[str, List[BaseEstimator]]):
4140
Dictionary containing list of numerical and categorical preprocessors.
42-
numerical_columns (List[str]):
43-
List of names of numerical columns
4441
categorical_columns (List[str]):
4542
List of names of categorical columns
4643
4744
Returns:
4845
ColumnTransformer
4946
"""
5047

51-
numerical_pipeline = 'drop'
52-
categorical_pipeline = 'drop'
53-
if len(numerical_columns) > 0:
54-
numerical_pipeline = make_pipeline(*preprocessors['numerical'])
55-
if len(categorical_columns) > 0:
56-
categorical_pipeline = make_pipeline(*preprocessors['categorical'])
48+
categorical_pipeline = make_pipeline(*preprocessors['categorical'])
5749

5850
return ColumnTransformer([
59-
('categorical_pipeline', categorical_pipeline, categorical_columns),
60-
('numerical_pipeline', numerical_pipeline, numerical_columns)],
61-
remainder='drop'
51+
('categorical_pipeline', categorical_pipeline, categorical_columns)],
52+
remainder='passthrough'
6253
)
6354

6455

6556
def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
6657
"""
6758
This function creates a Dictionary containing a list
6859
of numerical and categorical preprocessors
69-
7060
Returns:
7161
Dict[str, List[BaseEstimator]]
7262
"""
7363
preprocessors: Dict[str, List[BaseEstimator]] = dict()
7464

7565
# Categorical Preprocessors
76-
onehot_encoder = OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore')
66+
ordinal_encoder = OrdinalEncoder(handle_unknown='use_encoded_value',
67+
unknown_value=-1)
7768
categorical_imputer = SimpleImputer(strategy='constant', copy=False)
7869

79-
# Numerical Preprocessors
80-
numerical_imputer = SimpleImputer(strategy='median', copy=False)
81-
standard_scaler = StandardScaler(with_mean=True, with_std=True, copy=False)
82-
83-
preprocessors['categorical'] = [categorical_imputer, onehot_encoder]
84-
preprocessors['numerical'] = [numerical_imputer, standard_scaler]
70+
preprocessors['categorical'] = [categorical_imputer, ordinal_encoder]
8571

8672
return preprocessors
8773

@@ -176,31 +162,47 @@ def _fit(
176162
if hasattr(X, "iloc") and not issparse(X):
177163
X = cast(pd.DataFrame, X)
178164

179-
self.all_nan_columns = set([column for column in X.columns if X[column].isna().all()])
165+
all_nan_columns = X.columns[X.isna().all()]
166+
for col in all_nan_columns:
167+
X[col] = pd.to_numeric(X[col])
168+
169+
# Handle objects if possible
170+
exist_object_columns = has_object_columns(X.dtypes.values)
171+
if exist_object_columns:
172+
X = self.infer_objects(X)
180173

181-
categorical_columns, numerical_columns, feat_type = self._get_columns_info(X)
174+
self.dtypes = [dt.name for dt in X.dtypes] # Also note this change in self.dtypes
175+
self.all_nan_columns = set(all_nan_columns)
182176

183-
self.enc_columns = categorical_columns
177+
self.enc_columns, self.feat_type = self._get_columns_info(X)
184178

185-
preprocessors = get_tabular_preprocessors()
186-
self.column_transformer = _create_column_transformer(
187-
preprocessors=preprocessors,
188-
numerical_columns=numerical_columns,
189-
categorical_columns=categorical_columns,
190-
)
179+
if len(self.enc_columns) > 0:
191180

192-
# Mypy redefinition
193-
assert self.column_transformer is not None
194-
self.column_transformer.fit(X)
181+
preprocessors = get_tabular_preprocessors()
182+
self.column_transformer = _create_column_transformer(
183+
preprocessors=preprocessors,
184+
categorical_columns=self.enc_columns,
185+
)
195186

196-
# The column transformer reorders the feature types
197-
# therefore, we need to change the order of columns as well
198-
# This means categorical columns are shifted to the left
187+
# Mypy redefinition
188+
assert self.column_transformer is not None
189+
self.column_transformer.fit(X)
199190

200-
self.feat_type = sorted(
201-
feat_type,
202-
key=functools.cmp_to_key(self._comparator)
203-
)
191+
# The column transformer moves categorical columns before all numerical columns
192+
# therefore, we need to sort categorical columns so that it complies this change
193+
194+
self.feat_type = sorted(
195+
self.feat_type,
196+
key=functools.cmp_to_key(self._comparator)
197+
)
198+
199+
encoded_categories = self.column_transformer.\
200+
named_transformers_['categorical_pipeline'].\
201+
named_steps['ordinalencoder'].categories_
202+
self.categories = [
203+
list(range(len(cat)))
204+
for cat in encoded_categories
205+
]
204206

205207
# differently to categorical_columns and numerical_columns,
206208
# this saves the index of the column.
@@ -280,6 +282,23 @@ def transform(
280282
if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
281283
X = cast(Type[pd.DataFrame], X)
282284

285+
if self.all_nan_columns is None:
286+
raise ValueError('_fit must be called before calling transform')
287+
288+
for col in list(self.all_nan_columns):
289+
X[col] = np.nan
290+
X[col] = pd.to_numeric(X[col])
291+
292+
if len(self.categorical_columns) > 0:
293+
# when some categorical columns are not all nan in the training set
294+
# but they are all nan in the testing or validation set
295+
# we change those columns to `object` dtype
296+
# to ensure that these columns are changed to appropriate dtype
297+
# in self.infer_objects
298+
all_nan_cat_cols = set(X[self.enc_columns].columns[X[self.enc_columns].isna().all()])
299+
dtype_dict = {col: 'object' for col in self.enc_columns if col in all_nan_cat_cols}
300+
X = X.astype(dtype_dict)
301+
283302
# Check the data here so we catch problems on new test data
284303
self._check_data(X)
285304

@@ -288,11 +307,6 @@ def transform(
288307
# We need to convert the column in test data to
289308
# object otherwise the test column is interpreted as float
290309
if self.column_transformer is not None:
291-
if len(self.categorical_columns) > 0:
292-
categorical_columns = self.column_transformer.transformers_[0][-1]
293-
for column in categorical_columns:
294-
if X[column].isna().all():
295-
X[column] = X[column].astype('object')
296310
X = self.column_transformer.transform(X)
297311

298312
# Sparse related transformations
@@ -407,7 +421,6 @@ def _check_data(
407421
self.column_order = column_order
408422

409423
dtypes = [dtype.name for dtype in X.dtypes]
410-
411424
diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
412425
if len(self.dtypes) == 0:
413426
self.dtypes = dtypes
@@ -419,7 +432,7 @@ def _check_data(
419432
def _get_columns_info(
420433
self,
421434
X: pd.DataFrame,
422-
) -> Tuple[List[str], List[str], List[str]]:
435+
) -> Tuple[List[str], List[str]]:
423436
"""
424437
Return the columns to be encoded from a pandas dataframe
425438
@@ -438,15 +451,12 @@ def _get_columns_info(
438451
"""
439452

440453
# Register if a column needs encoding
441-
numerical_columns = []
442454
categorical_columns = []
443455
# Also, register the feature types for the estimator
444456
feat_type = []
445457

446458
# Make sure each column is a valid type
447459
for i, column in enumerate(X.columns):
448-
if self.all_nan_columns is not None and column in self.all_nan_columns:
449-
continue
450460
column_dtype = self.dtypes[i]
451461
err_msg = "Valid types are `numerical`, `categorical` or `boolean`, " \
452462
"but input column {} has an invalid type `{}`.".format(column, column_dtype)
@@ -457,7 +467,6 @@ def _get_columns_info(
457467
# TypeError: data type not understood in certain pandas types
458468
elif is_numeric_dtype(column_dtype):
459469
feat_type.append('numerical')
460-
numerical_columns.append(column)
461470
elif column_dtype == 'object':
462471
# TODO verify how would this happen when we always convert the object dtypes to category
463472
raise TypeError(
@@ -483,7 +492,7 @@ def _get_columns_info(
483492
"before feeding it to AutoPyTorch.".format(err_msg)
484493
)
485494

486-
return categorical_columns, numerical_columns, feat_type
495+
return categorical_columns, feat_type
487496

488497
def list_to_pandas(
489498
self,
@@ -553,22 +562,26 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
553562
pd.DataFrame
554563
"""
555564
if hasattr(self, 'object_dtype_mapping'):
556-
# Mypy does not process the has attr. This dict is defined below
557-
for key, dtype in self.object_dtype_mapping.items(): # type: ignore[has-type]
558-
# honor the training data types
559-
try:
560-
X[key] = X[key].astype(dtype.name)
561-
except Exception as e:
562-
# Try inference if possible
563-
self.logger.warning(f'Casting the column {key} to {dtype} caused the exception {e}')
564-
pass
565+
# honor the training data types
566+
try:
567+
# Mypy does not process the has attr.
568+
X = X.astype(self.object_dtype_mapping) # type: ignore[has-type]
569+
except Exception as e:
570+
# Try inference if possible
571+
self.logger.warning(f'Casting the columns to training dtypes ' # type: ignore[has-type]
572+
f'{self.object_dtype_mapping} caused the exception {e}')
573+
pass
565574
else:
566-
# Calling for the first time to infer the categories
567-
X = X.infer_objects()
568-
for column, data_type in zip(X.columns, X.dtypes):
569-
if not is_numeric_dtype(data_type):
570-
X[column] = X[column].astype('category')
571-
575+
if len(self.dtypes) != 0:
576+
# when train data has no object dtype, but test does
577+
# we prioritise the datatype given in training data
578+
dtype_dict = {col: dtype for col, dtype in zip(X.columns, self.dtypes)}
579+
X = X.astype(dtype_dict)
580+
else:
581+
# Calling for the first time to infer the categories
582+
X = X.infer_objects()
583+
dtype_dict = {col: 'category' for col, dtype in zip(X.columns, X.dtypes) if not is_numeric_dtype(dtype)}
584+
X = X.astype(dtype_dict)
572585
# only numerical attributes and categories
573586
self.object_dtype_mapping = {column: data_type for column, data_type in zip(X.columns, X.dtypes)}
574587

autoPyTorch/datasets/base_dataset.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ def __init__(
155155
self.holdout_validators: Dict[str, HoldOutFunc] = {}
156156
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
157157
self.random_state = np.random.RandomState(seed=seed)
158-
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
159158
self.shuffle = shuffle
160159
self.resampling_strategy = resampling_strategy
161160
self.resampling_strategy_args = resampling_strategy_args
@@ -165,10 +164,6 @@ def __init__(
165164
if len(self.train_tensors) == 2 and self.train_tensors[1] is not None:
166165
self.output_shape, self.output_type = _get_output_properties(self.train_tensors)
167166

168-
# TODO: Look for a criteria to define small enough to preprocess
169-
# False for the regularization cocktails initially
170-
self.is_small_preprocess = False
171-
172167
# Make sure cross validation splits are created once
173168
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
174169
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,6 @@ def __call__(self, random_state: np.random.RandomState, val_share: float,
3939
...
4040

4141

42-
class NoResamplingFunc(Protocol):
43-
def __call__(self,
44-
random_state: np.random.RandomState,
45-
indices: np.ndarray) -> np.ndarray:
46-
...
47-
48-
4942
class CrossValTypes(IntEnum):
5043
"""The type of cross validation
5144

0 commit comments

Comments
 (0)