Skip to content

Commit c3b8844

Browse files
committed
[FIX] Enable preprocessing in reg_cocktails (#369)
* enable preprocessing and remove is_small_preprocess
* address comments from shuhei and fix pre-commit checks
* fix tests
* fix pre-commit checks
* add suggestions from shuhei for astype use
* address speed issue when using object_dtype_mapping
* make code more readable
* improve documentation for base network embedding
1 parent 59b5830 commit c3b8844

33 files changed

+266
-770
lines changed

autoPyTorch/api/tabular_classification.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
1919
from autoPyTorch.datasets.resampling_strategy import (
2020
HoldoutValTypes,
21-
CrossValTypes,
2221
ResamplingStrategies,
2322
)
2423
from autoPyTorch.datasets.tabular_dataset import TabularDataset

autoPyTorch/api/tabular_regression.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
1919
from autoPyTorch.datasets.resampling_strategy import (
2020
HoldoutValTypes,
21-
CrossValTypes,
2221
ResamplingStrategies,
2322
)
2423
from autoPyTorch.datasets.tabular_dataset import TabularDataset

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 81 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.exceptions import NotFittedError
1717
from sklearn.impute import SimpleImputer
1818
from sklearn.pipeline import make_pipeline
19-
from sklearn.preprocessing import OneHotEncoder, StandardScaler
19+
from sklearn.preprocessing import OrdinalEncoder
2020

2121
from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SupportedFeatTypes
2222
from autoPyTorch.utils.common import ispandas
@@ -25,7 +25,6 @@
2525

2626
def _create_column_transformer(
2727
preprocessors: Dict[str, List[BaseEstimator]],
28-
numerical_columns: List[str],
2928
categorical_columns: List[str],
3029
) -> ColumnTransformer:
3130
"""
@@ -36,49 +35,36 @@ def _create_column_transformer(
3635
Args:
3736
preprocessors (Dict[str, List[BaseEstimator]]):
3837
Dictionary containing list of numerical and categorical preprocessors.
39-
numerical_columns (List[str]):
40-
List of names of numerical columns
4138
categorical_columns (List[str]):
4239
List of names of categorical columns
4340
4441
Returns:
4542
ColumnTransformer
4643
"""
4744

48-
numerical_pipeline = 'drop'
49-
categorical_pipeline = 'drop'
50-
if len(numerical_columns) > 0:
51-
numerical_pipeline = make_pipeline(*preprocessors['numerical'])
52-
if len(categorical_columns) > 0:
53-
categorical_pipeline = make_pipeline(*preprocessors['categorical'])
45+
categorical_pipeline = make_pipeline(*preprocessors['categorical'])
5446

5547
return ColumnTransformer([
56-
('categorical_pipeline', categorical_pipeline, categorical_columns),
57-
('numerical_pipeline', numerical_pipeline, numerical_columns)],
58-
remainder='drop'
48+
('categorical_pipeline', categorical_pipeline, categorical_columns)],
49+
remainder='passthrough'
5950
)
6051

6152

6253
def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
6354
"""
6455
This function creates a Dictionary containing a list
6556
of numerical and categorical preprocessors
66-
6757
Returns:
6858
Dict[str, List[BaseEstimator]]
6959
"""
7060
preprocessors: Dict[str, List[BaseEstimator]] = dict()
7161

7262
# Categorical Preprocessors
73-
onehot_encoder = OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore')
63+
ordinal_encoder = OrdinalEncoder(handle_unknown='use_encoded_value',
64+
unknown_value=-1)
7465
categorical_imputer = SimpleImputer(strategy='constant', copy=False)
7566

76-
# Numerical Preprocessors
77-
numerical_imputer = SimpleImputer(strategy='median', copy=False)
78-
standard_scaler = StandardScaler(with_mean=True, with_std=True, copy=False)
79-
80-
preprocessors['categorical'] = [categorical_imputer, onehot_encoder]
81-
preprocessors['numerical'] = [numerical_imputer, standard_scaler]
67+
preprocessors['categorical'] = [categorical_imputer, ordinal_encoder]
8268

8369
return preprocessors
8470

@@ -176,7 +162,16 @@ def _fit(
176162
if ispandas(X) and not issparse(X):
177163
X = cast(pd.DataFrame, X)
178164

179-
self.all_nan_columns = set([column for column in X.columns if X[column].isna().all()])
165+
all_nan_columns = X.columns[X.isna().all()]
166+
for col in all_nan_columns:
167+
X[col] = pd.to_numeric(X[col])
168+
169+
# Handle objects if possible
170+
exist_object_columns = has_object_columns(X.dtypes.values)
171+
if exist_object_columns:
172+
X = self.infer_objects(X)
173+
self.dtypes = [dt.name for dt in X.dtypes] # Also note this change in self.dtypes
174+
self.all_nan_columns = set(all_nan_columns)
180175

181176
self.transformed_columns, self.feat_types = self.get_columns_to_encode(X)
182177

@@ -188,18 +183,33 @@ def _fit(
188183
categorical_columns=self.transformed_columns,
189184
)
190185

191-
# Mypy redefinition
192-
assert self.column_transformer is not None
193-
self.column_transformer.fit(X)
186+
if len(self.enc_columns) > 0:
194187

195-
# The column transformer reorders the feature types
196-
# therefore, we need to change the order of columns as well
197-
# This means categorical columns are shifted to the left
188+
preprocessors = get_tabular_preprocessors()
189+
self.column_transformer = _create_column_transformer(
190+
preprocessors=preprocessors,
191+
categorical_columns=self.enc_columns,
192+
)
198193

199-
self.feat_types = sorted(
200-
self.feat_types,
201-
key=functools.cmp_to_key(self._comparator)
202-
)
194+
# Mypy redefinition
195+
assert self.column_transformer is not None
196+
self.column_transformer.fit(X)
197+
198+
# The column transformer moves categorical columns before all numerical columns
199+
# therefore, we need to sort the feature types so that they comply with this change
200+
201+
self.feat_types = sorted(
202+
self.feat_types,
203+
key=functools.cmp_to_key(self._comparator)
204+
)
205+
206+
encoded_categories = self.column_transformer.\
207+
named_transformers_['categorical_pipeline'].\
208+
named_steps['ordinalencoder'].categories_
209+
self.categories = [
210+
list(range(len(cat)))
211+
for cat in encoded_categories
212+
]
203213

204214
# differently to categorical_columns and numerical_columns,
205215
# this saves the index of the column.
@@ -279,6 +289,23 @@ def transform(
279289
if ispandas(X) and not issparse(X):
280290
X = cast(pd.DataFrame, X)
281291

292+
if self.all_nan_columns is None:
293+
raise ValueError('_fit must be called before calling transform')
294+
295+
for col in list(self.all_nan_columns):
296+
X[col] = np.nan
297+
X[col] = pd.to_numeric(X[col])
298+
299+
if len(self.categorical_columns) > 0:
300+
# when some categorical columns are not all nan in the training set
301+
# but they are all nan in the testing or validation set
302+
# we change those columns to `object` dtype
303+
# to ensure that these columns are changed to appropriate dtype
304+
# in self.infer_objects
305+
all_nan_cat_cols = set(X[self.enc_columns].columns[X[self.enc_columns].isna().all()])
306+
dtype_dict = {col: 'object' for col in self.enc_columns if col in all_nan_cat_cols}
307+
X = X.astype(dtype_dict)
308+
282309
# Check the data here so we catch problems on new test data
283310
self._check_data(X)
284311

@@ -287,11 +314,6 @@ def transform(
287314
# We need to convert the column in test data to
288315
# object otherwise the test column is interpreted as float
289316
if self.column_transformer is not None:
290-
if len(self.categorical_columns) > 0:
291-
categorical_columns = self.column_transformer.transformers_[0][-1]
292-
for column in categorical_columns:
293-
if X[column].isna().all():
294-
X[column] = X[column].astype('object')
295317
X = self.column_transformer.transform(X)
296318

297319
# Sparse related transformations
@@ -380,7 +402,6 @@ def _check_data(
380402
self.column_order = column_order
381403

382404
dtypes = [dtype.name for dtype in X.dtypes]
383-
384405
diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
385406
if len(self.dtypes) == 0:
386407
self.dtypes = dtypes
@@ -448,7 +469,7 @@ def _validate_feat_types(self, X: pd.DataFrame) -> None:
448469
def _get_columns_to_encode(
449470
self,
450471
X: pd.DataFrame,
451-
) -> Tuple[List[str], List[str], List[str]]:
472+
) -> Tuple[List[str], List[str]]:
452473
"""
453474
Return the columns to be transformed as well as
454475
the type of feature for each column from a pandas dataframe.
@@ -478,8 +499,8 @@ def _get_columns_to_encode(
478499
# Also, register the feature types for the estimator
479500
feat_types = []
480501

481-
# Make sure each column is a valid type
482-
for column in X.columns:
502+
# Make sure each column is a valid type
503+
for i, column in enumerate(X.columns):
483504
if self.all_nan_columns is not None and column in self.all_nan_columns:
484505
continue
485506
column_dtype = self.dtypes[i]
@@ -592,22 +613,26 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
592613
pd.DataFrame
593614
"""
594615
if hasattr(self, 'object_dtype_mapping'):
595-
# Mypy does not process the has attr. This dict is defined below
596-
for key, dtype in self.object_dtype_mapping.items(): # type: ignore[has-type]
597-
# honor the training data types
598-
try:
599-
X[key] = X[key].astype(dtype.name)
600-
except Exception as e:
601-
# Try inference if possible
602-
self.logger.warning(f'Casting the column {key} to {dtype} caused the exception {e}')
603-
pass
616+
# honor the training data types
617+
try:
618+
# Mypy does not take the hasattr check into account.
619+
X = X.astype(self.object_dtype_mapping) # type: ignore[has-type]
620+
except Exception as e:
621+
# Try inference if possible
622+
self.logger.warning(f'Casting the columns to training dtypes ' # type: ignore[has-type]
623+
f'{self.object_dtype_mapping} caused the exception {e}')
624+
pass
604625
else:
605-
# Calling for the first time to infer the categories
606-
X = X.infer_objects()
607-
for column, data_type in zip(X.columns, X.dtypes):
608-
if not is_numeric_dtype(data_type):
609-
X[column] = X[column].astype('category')
610-
626+
if len(self.dtypes) != 0:
627+
# when train data has no object dtype, but test does
628+
# we prioritise the datatype given in training data
629+
dtype_dict = {col: dtype for col, dtype in zip(X.columns, self.dtypes)}
630+
X = X.astype(dtype_dict)
631+
else:
632+
# Calling for the first time to infer the categories
633+
X = X.infer_objects()
634+
dtype_dict = {col: 'category' for col, dtype in zip(X.columns, X.dtypes) if not is_numeric_dtype(dtype)}
635+
X = X.astype(dtype_dict)
611636
# only numerical attributes and categories
612637
self.object_dtype_mapping = {column: data_type for column, data_type in zip(X.columns, X.dtypes)}
613638

autoPyTorch/datasets/base_dataset.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ def __init__(
155155
self.holdout_validators: Dict[str, HoldOutFunc] = {}
156156
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
157157
self.random_state = np.random.RandomState(seed=seed)
158-
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
159158
self.shuffle = shuffle
160159
self.resampling_strategy = resampling_strategy
161160
self.resampling_strategy_args = resampling_strategy_args
@@ -165,10 +164,6 @@ def __init__(
165164
if len(self.train_tensors) == 2 and self.train_tensors[1] is not None:
166165
self.output_shape, self.output_type = _get_output_properties(self.train_tensors)
167166

168-
# TODO: Look for a criteria to define small enough to preprocess
169-
# False for the regularization cocktails initially
170-
self.is_small_preprocess = False
171-
172167
# Make sure cross validation splits are created once
173168
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
174169
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)

0 commit comments

Comments
 (0)