Skip to content

Commit ef1057a

Browse files
[FIX] Passing checks (#298)
* Initial fix for all tests passing locally py=3.8 * fix bug in tests * fix bug in test for data * debugging error in dummy forward pass * debug try -2 * catch runtime error in ci * catch runtime error in ci * add better debug test setup * debug some more * run this test only * remove sum backward * remove inplace in inception block * undo silly change * Enable all tests * fix flake * fix bug in test setup * remove anomaly detection * minor changes to comments * Apply suggestions from code review Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com> * Address comments from Shuhei * revert change leading to bug * fix flake * change comment position in feature validator * Add documentation for _is_datasets_consistent * address comments from arlind * case when all nans in test Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent 43d4639 commit ef1057a

File tree

19 files changed

+132
-120
lines changed

19 files changed

+132
-120
lines changed

autoPyTorch/api/base_task.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,7 +1386,7 @@ def fit_ensemble(
13861386
Args:
13871387
optimize_metric (str): name of the metric that is used to
13881388
evaluate a pipeline. If not specified, the value passed to search will be used
1389-
precision (int), (default=32): Numeric precision used when loading
1389+
precision (Optional[int]): Numeric precision used when loading
13901390
ensemble data. Can be either 16, 32 or 64.
13911391
ensemble_nbest (Optional[int]):
13921392
only consider the ensemble_nbest models to build the ensemble.
@@ -1429,6 +1429,7 @@ def fit_ensemble(
14291429
"Please call the `search()` method of {} prior to "
14301430
"fit_ensemble().".format(self.__class__.__name__))
14311431

1432+
precision = precision if precision is not None else self.precision
14321433
if precision not in [16, 32, 64]:
14331434
raise ValueError("precision must be one of 16, 32, 64 but got {}".format(precision))
14341435

@@ -1479,7 +1480,7 @@ def fit_ensemble(
14791480
manager = self._init_ensemble_builder(
14801481
time_left_for_ensembles=time_left_for_ensemble,
14811482
optimize_metric=self.opt_metric if optimize_metric is None else optimize_metric,
1482-
precision=self.precision if precision is None else precision,
1483+
precision=precision,
14831484
ensemble_size=ensemble_size,
14841485
ensemble_nbest=ensemble_nbest,
14851486
)

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import functools
2-
from typing import Dict, List, Optional, Tuple, cast
2+
from typing import Dict, List, Optional, Tuple, Union, cast
33

44
import numpy as np
55

@@ -100,6 +100,7 @@ def _comparator(cmp1: str, cmp2: str) -> int:
100100
if cmp1 not in choices or cmp2 not in choices:
101101
raise ValueError('The comparator for the column order only accepts {}, '
102102
'but got {} and {}'.format(choices, cmp1, cmp2))
103+
103104
idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
104105
return idx1 - idx2
105106

@@ -246,13 +247,12 @@ def transform(
246247
# having a value for a categorical column.
247248
# We need to convert the column in test data to
248249
# object otherwise the test column is interpreted as float
249-
if len(self.categorical_columns) > 0:
250-
categorical_columns = self.column_transformer.transformers_[0][-1]
251-
for column in categorical_columns:
252-
if X[column].isna().all():
253-
X[column] = X[column].astype('object')
254-
255250
if self.column_transformer is not None:
251+
if len(self.categorical_columns) > 0:
252+
categorical_columns = self.column_transformer.transformers_[0][-1]
253+
for column in categorical_columns:
254+
if X[column].isna().all():
255+
X[column] = X[column].astype('object')
256256
X = self.column_transformer.transform(X)
257257

258258
# Sparse related transformations
@@ -337,16 +337,10 @@ def _check_data(
337337

338338
dtypes = [dtype.name for dtype in X.dtypes]
339339

340-
dtypes_diff = [s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]
340+
diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
341341
if len(self.dtypes) == 0:
342342
self.dtypes = dtypes
343-
elif (
344-
any(dtypes_diff) # the dtypes of some columns are different in train and test dataset
345-
and self.all_nan_columns is not None # Ignore all_nan_columns is None
346-
and len(set(X.columns[dtypes_diff]).difference(self.all_nan_columns)) != 0
347-
):
348-
# The dtypes can be different if and only if the column belongs
349-
# to all_nan_columns as these columns would be imputed.
343+
elif not self._is_datasets_consistent(diff_cols, X):
350344
raise ValueError("The dtype of the features must not be changed after fit(), but"
351345
" the dtypes of some columns are different between training ({}) and"
352346
" test ({}) datasets.".format(self.dtypes, dtypes))
@@ -508,6 +502,33 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
508502

509503
return X
510504

505+
def _is_datasets_consistent(self, diff_cols: List[Union[int, str]], X: pd.DataFrame) -> bool:
506+
"""
507+
Check the consistency of dtypes between training and test datasets.
508+
The dtypes can be different if the column belongs to `self.all_nan_columns`
509+
(list of column names with all nans in training data) or if the column is
510+
all nan as these columns would be imputed.
511+
512+
Args:
513+
diff_cols (List[Union[int, str]]):
514+
The column labels that have different dtypes.
515+
X (pd.DataFrame):
516+
A validation or test dataset to be compared with the training dataset
517+
Returns:
518+
_ (bool): Whether the training and test datasets are consistent.
519+
"""
520+
if self.all_nan_columns is None:
521+
if len(diff_cols) == 0:
522+
return True
523+
else:
524+
return all(X[diff_cols].isna().all())
525+
526+
# dtype is different ==> the column must be all NaN in at least one of the train or test datasets
527+
# inconsistent <==> dtype is different and the column is all NaN in neither the train nor the test dataset
528+
inconsistent_cols = list(set(diff_cols) - self.all_nan_columns)
529+
530+
return len(inconsistent_cols) == 0 or all(X[inconsistent_cols].isna().all())
531+
511532

512533
def has_object_columns(
513534
feature_types: pd.Series,

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/NoEncoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
3939
Returns:
4040
(Dict[str, Any]): the updated 'X' dictionary
4141
"""
42-
X.update({'encoder': self.preprocessor})
42+
# X.update({'encoder': self.preprocessor})
4343
return X
4444

4545
@staticmethod

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/NoScaler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
4242
Returns:
4343
(Dict[str, Any]): the updated 'X' dictionary
4444
"""
45-
X.update({'scaler': self.preprocessor})
45+
# X.update({'scaler': self.preprocessor})
4646
return X
4747

4848
@staticmethod

autoPyTorch/pipeline/components/setup/network_backbone/InceptionTimeBackbone.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(self, n_res_inputs: int, n_outputs: int):
7878
def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
7979
shortcut = self.shortcut(res)
8080
shortcut = self.bn(shortcut)
81-
x += shortcut
81+
x = x + shortcut
8282
return torch.relu(x)
8383

8484

autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
2121

2222
self.embedding = self.build_embedding(
2323
num_input_features=num_input_features,
24-
num_numerical_features=num_numerical_columns)
24+
num_numerical_features=num_numerical_columns) # type: ignore[arg-type]
2525
return self
2626

2727
def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:

autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,7 @@ def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torc
109109
loss = loss_func(self.criterion, original_outputs, adversarial_outputs)
110110
loss.backward()
111111
self.optimizer.step()
112-
if self.scheduler:
113-
if 'ReduceLROnPlateau' in self.scheduler.__class__.__name__:
114-
self.scheduler.step(loss)
115-
else:
116-
self.scheduler.step()
112+
117113
# only passing the original outputs since we do not care about
118114
# the adversarial performance.
119115
return loss.item(), original_outputs

autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom
280280
y=y,
281281
**kwargs
282282
)
283+
283284
# Add snapshots to base network to enable
284285
# predicting with snapshot ensemble
285286
self.choice: autoPyTorchComponent = cast(autoPyTorchComponent, self.choice)

examples/tabular/40_advanced/example_custom_configuration_space.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def get_search_space_updates():
5959
value_range=['shake-shake'],
6060
default_value='shake-shake')
6161
updates.append(node_name='network_backbone',
62-
hyperparameter='ResNetBackbone:shake_shake_method',
62+
hyperparameter='ResNetBackbone:shake_shake_update_func',
6363
value_range=['M3'],
6464
default_value='M3'
6565
)

test/test_data/test_feature_validator.py

Lines changed: 31 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ def test_featurevalidator_supported_types(input_data_featuretest):
204204
assert sparse.issparse(transformed_X)
205205
else:
206206
assert isinstance(transformed_X, np.ndarray)
207-
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
208207
assert np.issubdtype(transformed_X.dtype, np.number)
209208
assert validator._is_fitted
210209

@@ -237,9 +236,10 @@ def test_featurevalidator_categorical_nan(input_data_featuretest):
237236
validator.fit(input_data_featuretest)
238237
transformed_X = validator.transform(input_data_featuretest)
239238
assert any(pd.isna(input_data_featuretest))
240-
assert any((-1 in categories) or ('-1' in categories) or ('Missing!' in categories) for categories in
241-
validator.encoder.named_transformers_['encoder'].categories_)
242-
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
239+
categories_ = validator.column_transformer.\
240+
named_transformers_['categorical_pipeline'].named_steps['onehotencoder'].categories_
241+
assert any(('0' in categories) or (0 in categories) or ('missing_value' in categories) for categories in
242+
categories_)
243243
assert np.issubdtype(transformed_X.dtype, np.number)
244244
assert validator._is_fitted
245245
assert isinstance(transformed_X, np.ndarray)
@@ -292,7 +292,6 @@ def test_featurevalidator_fitontypeA_transformtypeB(input_data_featuretest):
292292
else:
293293
raise ValueError(type(input_data_featuretest))
294294
transformed_X = validator.transform(complementary_type)
295-
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
296295
assert np.issubdtype(transformed_X.dtype, np.number)
297296
assert validator._is_fitted
298297

@@ -436,36 +435,29 @@ def test_features_unsupported_calls_are_raised():
436435
expected
437436
"""
438437
validator = TabularFeatureValidator()
439-
with pytest.raises(ValueError, match=r"AutoPyTorch does not support time"):
438+
with pytest.raises(TypeError, match=r".*?Convert the time information to a numerical value"):
440439
validator.fit(
441440
pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})
442441
)
442+
validator = TabularFeatureValidator()
443443
with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"):
444444
validator.fit({'input1': 1, 'input2': 2})
445-
with pytest.raises(ValueError, match=r"has unsupported dtype string"):
445+
validator = TabularFeatureValidator()
446+
with pytest.raises(TypeError, match=r".*?but input column A has an invalid type `string`.*"):
446447
validator.fit(pd.DataFrame([{'A': 1, 'B': 2}], dtype='string'))
448+
validator = TabularFeatureValidator()
447449
with pytest.raises(ValueError, match=r"The feature dimensionality of the train and test"):
448450
validator.fit(X_train=np.array([[1, 2, 3], [4, 5, 6]]),
449451
X_test=np.array([[1, 2, 3, 4], [4, 5, 6, 7]]),
450452
)
453+
validator = TabularFeatureValidator()
451454
with pytest.raises(ValueError, match=r"Cannot call transform on a validator that is not fit"):
452455
validator.transform(np.array([[1, 2, 3], [4, 5, 6]]))
453456

454457

455458
@pytest.mark.parametrize(
456459
'input_data_featuretest',
457460
(
458-
'numpy_numericalonly_nonan',
459-
'numpy_numericalonly_nan',
460-
'pandas_numericalonly_nonan',
461-
'pandas_numericalonly_nan',
462-
'list_numericalonly_nonan',
463-
'list_numericalonly_nan',
464-
# Category in numpy is handled via feat_type
465-
'numpy_categoricalonly_nonan',
466-
'numpy_mixed_nonan',
467-
'numpy_categoricalonly_nan',
468-
'numpy_mixed_nan',
469461
'sparse_bsr_nonan',
470462
'sparse_bsr_nan',
471463
'sparse_coo_nonan',
@@ -483,14 +475,14 @@ def test_features_unsupported_calls_are_raised():
483475
),
484476
indirect=True
485477
)
486-
def test_no_encoder_created(input_data_featuretest):
478+
def test_no_column_transformer_created(input_data_featuretest):
487479
"""
488480
Makes sure that for numerical only features, no column transformer is created
489481
"""
490482
validator = TabularFeatureValidator()
491483
validator.fit(input_data_featuretest)
492484
validator.transform(input_data_featuretest)
493-
assert validator.encoder is None
485+
assert validator.column_transformer is None
494486

495487

496488
@pytest.mark.parametrize(
@@ -501,18 +493,18 @@ def test_no_encoder_created(input_data_featuretest):
501493
),
502494
indirect=True
503495
)
504-
def test_encoder_created(input_data_featuretest):
496+
def test_column_transformer_created(input_data_featuretest):
505497
"""
506-
This test ensures an encoder is created if categorical data is provided
498+
This test ensures a column transformer is created if categorical data is provided
507499
"""
508500
validator = TabularFeatureValidator()
509501
validator.fit(input_data_featuretest)
510502
transformed_X = validator.transform(input_data_featuretest)
511-
assert validator.encoder is not None
503+
assert validator.column_transformer is not None
512504

513505
# Make sure that the encoded features are actually encoded. Categorical columns are at
514506
# the start after transformation. In our fixtures, this is also honored prior encode
515-
enc_columns, feature_types = validator._get_columns_to_encode(input_data_featuretest)
507+
cat_columns, _, feature_types = validator._get_columns_info(input_data_featuretest)
516508

517509
# At least one categorical
518510
assert 'categorical' in validator.feat_type
@@ -521,20 +513,13 @@ def test_encoder_created(input_data_featuretest):
521513
if np.any([pd.api.types.is_numeric_dtype(input_data_featuretest[col]
522514
) for col in input_data_featuretest.columns]):
523515
assert 'numerical' in validator.feat_type
524-
for i, feat_type in enumerate(feature_types):
525-
if 'numerical' in feat_type:
526-
np.testing.assert_array_equal(
527-
transformed_X[:, i],
528-
input_data_featuretest[input_data_featuretest.columns[i]].to_numpy()
529-
)
530-
elif 'categorical' in feat_type:
531-
np.testing.assert_array_equal(
532-
transformed_X[:, i],
533-
# Expect always 0, 1... because we use a ordinal encoder
534-
np.array([0, 1])
535-
)
536-
else:
537-
raise ValueError(feat_type)
516+
# we expect this input to be the fixture 'pandas_mixed_nan'
517+
np.testing.assert_array_equal(transformed_X, np.array([[1., 0., -1.], [0., 1., 1.]]))
518+
else:
519+
np.testing.assert_array_equal(transformed_X, np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]]))
520+
521+
if not all([feat_type in ['numerical', 'categorical'] for feat_type in feature_types]):
522+
raise ValueError("Expected only numerical and categorical feature types")
538523

539524

540525
def test_no_new_category_after_fit():
@@ -566,13 +551,12 @@ def test_unknown_encode_value():
566551
x['c'].cat.add_categories(['NA'], inplace=True)
567552
x.loc[0, 'c'] = 'NA' # unknown value
568553
x_t = validator.transform(x)
569-
# The first row should have a -1 as we added a new categorical there
570-
expected_row = [-1, -41, -3, -987.2]
554+
# The first row should have a 0, 0 as we added a
555+
# new category there and the one-hot encoder marks
556+
# it as all zeros for the transformed column
557+
expected_row = [0.0, 0.0, -0.5584294383572701, 0.5000000000000004, -1.5136598016833485]
571558
assert expected_row == x_t[0].tolist()
572559

573-
# Notice how there is only one column 'c' to encode
574-
assert validator.categories == [list(range(2)) for i in range(1)]
575-
576560

577561
# Actual checks for the features
578562
@pytest.mark.parametrize(
@@ -624,19 +608,20 @@ def test_feature_validator_new_data_after_fit(
624608
assert sparse.issparse(transformed_X)
625609
else:
626610
assert isinstance(transformed_X, np.ndarray)
627-
assert np.shape(X_test) == np.shape(transformed_X)
628611

629612
# And then check proper error messages
630613
if train_data_type == 'pandas':
631614
old_dtypes = copy.deepcopy(validator.dtypes)
632615
validator.dtypes = ['dummy' for dtype in X_train.dtypes]
633-
with pytest.raises(ValueError, match=r"Changing the dtype of the features after fit"):
616+
with pytest.raises(ValueError,
617+
match=r"The dtype of the features must not be changed after fit"):
634618
transformed_X = validator.transform(X_test)
635619
validator.dtypes = old_dtypes
636620
if test_data_type == 'pandas':
637621
columns = X_test.columns.tolist()
638622
X_test = X_test[reversed(columns)]
639-
with pytest.raises(ValueError, match=r"Changing the column order of the features"):
623+
with pytest.raises(ValueError,
624+
match=r"The column order of the features must not be changed after fit"):
640625
transformed_X = validator.transform(X_test)
641626

642627

0 commit comments

Comments
 (0)