Skip to content

Commit 09ad0d7

Browse files
ArlindKadra authored and ravinkohli committed
Cocktail hotfixes (#245)
* Fixes for the development branch and regularization cocktails
* Update implementation
* Fix unit tests temporarily
* Implementation update and bug fixes
* Removing unnecessary code
* Addressing Ravin's comments
[refactor] Address Shuhei's comments
[refactor] Address Shuhei's comments
[refactor] Address Shuhei's comments
[refactor] Address Shuhei's comments
1 parent d140442 commit 09ad0d7

File tree

12 files changed

+404
-48
lines changed

12 files changed

+404
-48
lines changed

autoPyTorch/api/base_task.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -407,7 +407,7 @@ def set_pipeline_config(self, **pipeline_config_kwargs: Any) -> None:
407407
None
408408
"""
409409
unknown_keys = []
410-
for option, value in pipeline_config_kwargs.items():
410+
for option in pipeline_config_kwargs.keys():
411411
if option in self.pipeline_options.keys():
412412
pass
413413
else:

autoPyTorch/api/tabular_classification.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,16 @@ def search(
389389
y_test=y_test,
390390
resampling_strategy=self.resampling_strategy,
391391
resampling_strategy_args=self.resampling_strategy_args,
392-
dataset_name=dataset_name)
392+
dataset_name=dataset_name
393+
)
394+
395+
if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
396+
raise ValueError(
397+
'Hyperparameter optimization requires a validation split. '
398+
'Expected `self.resampling_strategy` to be either '
399+
'(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
400+
)
401+
393402

394403
return self._search(
395404
dataset=self.dataset,
@@ -430,23 +439,23 @@ def predict(
430439
raise ValueError("predict() is only supported after calling search. Kindly call first "
431440
"the estimator search() method.")
432441

433-
X_test = self.InputValidator.feature_validator.transform(X_test)
442+
X_test = self.input_validator.feature_validator.transform(X_test)
434443
predicted_probabilities = super().predict(X_test, batch_size=batch_size,
435444
n_jobs=n_jobs)
436445

437-
if self.InputValidator.target_validator.is_single_column_target():
446+
if self.input_validator.target_validator.is_single_column_target():
438447
predicted_indexes = np.argmax(predicted_probabilities, axis=1)
439448
else:
440449
predicted_indexes = (predicted_probabilities > 0.5).astype(int)
441450

442451
# Allow to predict in the original domain -- that is, the user is not interested
443452
# in our encoded values
444-
return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
453+
return self.input_validator.target_validator.inverse_transform(predicted_indexes)
445454

446455
def predict_proba(self,
447456
X_test: Union[np.ndarray, pd.DataFrame, List],
448457
batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
449-
if self.InputValidator is None or not self.InputValidator._is_fitted:
458+
if self.input_validator is None or not self.input_validator._is_fitted:
450459
raise ValueError("predict() is only supported after calling search. Kindly call first "
451460
"the estimator search() method.")
452461
X_test = self.InputValidator.feature_validator.transform(X_test)

autoPyTorch/api/tabular_regression.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,16 @@ def search(
389389
y_test=y_test,
390390
resampling_strategy=self.resampling_strategy,
391391
resampling_strategy_args=self.resampling_strategy_args,
392-
dataset_name=dataset_name)
392+
dataset_name=dataset_name
393+
)
394+
395+
if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
396+
raise ValueError(
397+
'Hyperparameter optimization requires a validation split. '
398+
'Expected `self.resampling_strategy` to be either '
399+
'(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
400+
)
401+
393402

394403
return self._search(
395404
dataset=self.dataset,
@@ -416,14 +425,14 @@ def predict(
416425
batch_size: Optional[int] = None,
417426
n_jobs: int = 1
418427
) -> np.ndarray:
419-
if self.InputValidator is None or not self.InputValidator._is_fitted:
428+
if self.input_validator is None or not self.input_validator._is_fitted:
420429
raise ValueError("predict() is only supported after calling search. Kindly call first "
421430
"the estimator search() method.")
422431

423-
X_test = self.InputValidator.feature_validator.transform(X_test)
432+
X_test = self.input_validator.feature_validator.transform(X_test)
424433
predicted_values = super().predict(X_test, batch_size=batch_size,
425434
n_jobs=n_jobs)
426435

427436
# Allow to predict in the original domain -- that is, the user is not interested
428437
# in our encoded values
429-
return self.InputValidator.target_validator.inverse_transform(predicted_values)
438+
return self.input_validator.target_validator.inverse_transform(predicted_values)

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def _get_columns_to_encode(
391391
feat_type = []
392392

393393
# Make sure each column is a valid type
394-
for i, column in enumerate(X.columns):
394+
for column in X.columns:
395395
if X[column].dtype.name in ['category', 'bool']:
396396

397397
transformed_columns.append(column)
@@ -512,7 +512,7 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
512512
X[key] = X[key].astype(dtype.name)
513513
except Exception as e:
514514
# Try inference if possible
515-
self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
515+
self.logger.warning(f'Casting the column {key} to {dtype} caused the exception {e}')
516516
pass
517517
else:
518518
X = X.infer_objects()

0 commit comments

Comments
 (0)