Feature preprocessors, Loss strategies #86

Merged
Changes from all commits
27 commits
cc583c1
ADD Weighted loss
ravinkohli Feb 5, 2021
2ea059c
Now?
ravinkohli Feb 5, 2021
14795cc
Merge branch 'feature_preprocessing' into missing_components
ravinkohli Feb 5, 2021
9f0ed18
Fix tests, flake, mypy
ravinkohli Feb 5, 2021
fb23cef
Fix tests
ravinkohli Feb 5, 2021
a8ea7b5
Fix mypy
ravinkohli Feb 8, 2021
8a618ff
change back sklearn requirement
ravinkohli Feb 8, 2021
8a389b2
Assert for fast ica sklearn bug
ravinkohli Feb 8, 2021
ce1778b
Forgot to add skip
ravinkohli Feb 8, 2021
0b1f3f0
Fix tests, changed num only data to float
ravinkohli Feb 8, 2021
0795b44
removed fast ica
ravinkohli Feb 8, 2021
bf69120
change num only dataset
ravinkohli Feb 8, 2021
e7d8606
Increased number of features in num only
ravinkohli Feb 8, 2021
8a95f61
Increase timeout for pytest
ravinkohli Feb 8, 2021
0a2d74f
ADD tensorboard to requirement
ravinkohli Feb 8, 2021
36b2c22
Fix bug with small_preprocess
ravinkohli Feb 9, 2021
d222826
Fix bug in pytest execution
ravinkohli Feb 9, 2021
90fdcfe
Fix tests
ravinkohli Feb 9, 2021
df9ec6e
ADD error is raised if default not in include
ravinkohli Feb 9, 2021
95378f7
Added dynamic search space for deciding n components in feature prepr…
ravinkohli Feb 9, 2021
0c88cab
Moved back to random configs in tabular test
ravinkohli Feb 9, 2021
3ec87b1
Added floor and ceil and handling of logs
ravinkohli Feb 9, 2021
6546d5c
Fix flake
ravinkohli Feb 9, 2021
9388c32
Remove TruncatedSVD from cs if num numerical ==1
ravinkohli Feb 9, 2021
b6c2cd0
ADD flakyness to network accuracy test
ravinkohli Feb 9, 2021
18f79e2
fix flake
ravinkohli Feb 9, 2021
e974969
remove cla to pytest
ravinkohli Feb 9, 2021
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -29,7 +29,7 @@ jobs:
       - name: Run tests
         run: |
           if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml'; fi
-          python -m pytest -n 2 --timeout=600 --timeout-method=thread --dist load test -sv $codecov
+          python -m pytest --durations=20 --timeout=300 --timeout-method=thread -v $codecov test
       - name: Check for files left behind by test
         if: ${{ always() }}
         run: |
9 changes: 9 additions & 0 deletions autoPyTorch/datasets/base_dataset.py
@@ -48,6 +48,7 @@ class TransformSubset(Subset):

     We achieve so by adding a train flag to the pytorch subset
     """
+
    def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None:
        self.dataset = dataset
        self.indices = indices
@@ -371,3 +372,11 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
             'num_classes': self.num_classes,
         })
         return dataset_properties
+
+    def get_required_dataset_info(self) -> Dict[str, Any]:
+        """
+        Returns a dictionary containing required dataset properties to instantiate a pipeline,
+        """
+        info = {'output_type': self.output_type,
+                'issparse': self.issparse}
+        return info
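
A minimal sketch of how a dictionary from get_required_dataset_info() could be consumed. The compatible_preprocessors helper below is hypothetical, not part of this PR; it only mirrors the 'handles_sparse' flag that the components' get_properties() methods expose.

from typing import Any, Dict, List


def compatible_preprocessors(info: Dict[str, Any],
                             components: Dict[str, Dict[str, Any]]) -> List[str]:
    """Return names of components whose properties match the dataset info."""
    names = []
    for name, properties in components.items():
        # Sparse datasets can only use components that declare sparse support.
        if info.get('issparse') and not properties.get('handles_sparse', False):
            continue
        names.append(name)
    return names


info = {'output_type': 'binary', 'issparse': True}
components = {'KernelPCA': {'handles_sparse': True},
              'SomeDensePreprocessor': {'handles_sparse': False}}  # made-up name
print(compatible_preprocessors(info, components))  # ['KernelPCA']
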
17 changes: 13 additions & 4 deletions autoPyTorch/datasets/tabular_dataset.py
@@ -104,7 +104,6 @@ def __init__(self, X: Union[np.ndarray, pd.DataFrame],
         # rather to have a performance through time on the test data
         if X_test is not None:
             X_test, self._test_data_types, _, _, _ = self.interpret_columns(X_test)
-
         # Some quality checks on the data
         if self.data_types != self._test_data_types:
             raise ValueError(f"The train data inferred types {self.data_types} are "
@@ -205,8 +204,7 @@ def interpret_columns(self,

         return data, data_types, nan_mask, itovs, vtois

-    def infer_dataset_properties(self, X: Any) \
-            -> Tuple[List[int], List[int], List[object], int]:
+    def infer_dataset_properties(self, X: Any) -> Tuple[List[int], List[int], List[object], int]:
         """
         Infers the properties of the dataset like
         categorical_columns, numerical_columns, categories, num_features
@@ -225,5 +223,16 @@ def infer_dataset_properties(self, X: Any) -> Tuple[List[int], List[int], List[object], int]:
                 numerical_columns.append(i)
         categories = [np.unique(X.iloc[:, a]).tolist() for a in categorical_columns]
         num_features = X.shape[1]
-
         return categorical_columns, numerical_columns, categories, num_features
+
+    def get_required_dataset_info(self) -> Dict[str, Any]:
+        """
+        Returns a dictionary containing required dataset properties to instantiate a pipeline,
+        """
+        info = super().get_required_dataset_info()
+        info.update({
+            'numerical_columns': self.numerical_columns,
+            'categorical_columns': self.categorical_columns,
+            'task_type': self.task_type
+        })
+        return info
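
An illustrative example (not from the PR, concrete values made up) of the merged dictionary the TabularDataset override produces:

info = {
    'output_type': 'binary',                # from BaseDataset.get_required_dataset_info()
    'issparse': False,                      # from BaseDataset.get_required_dataset_info()
    'numerical_columns': [0, 2],            # added by the TabularDataset override
    'categorical_columns': [1],
    'task_type': 'tabular_classification',  # value format is an assumption
}
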
4 changes: 2 additions & 2 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -32,8 +32,7 @@
 from autoPyTorch.datasets.base_dataset import BaseDataset
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.evaluation.utils import (
-    convert_multioutput_multiclass_to_multilabel,
-    subsampler
+    convert_multioutput_multiclass_to_multilabel
 )
 from autoPyTorch.pipeline.base_pipeline import BasePipeline
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
@@ -42,6 +41,7 @@
     get_metrics,
 )
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.common import subsampler
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
 from autoPyTorch.utils.pipeline import get_dataset_requirements
2 changes: 1 addition & 1 deletion autoPyTorch/evaluation/train_evaluator.py
@@ -17,9 +17,9 @@
     AbstractEvaluator,
     fit_and_suppress_warnings
 )
-from autoPyTorch.evaluation.utils import subsampler
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.common import subsampler
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

 __all__ = ['TrainEvaluator', 'eval_function']
8 changes: 0 additions & 8 deletions autoPyTorch/evaluation/utils.py
@@ -4,8 +4,6 @@

 import numpy as np

-import pandas as pd
-
 from smac.runhistory.runhistory import RunValue

 __all__ = [
@@ -16,12 +14,6 @@
 ]


-def subsampler(data: Union[np.ndarray, pd.DataFrame],
-               x: Union[np.ndarray, List[int]]
-               ) -> Union[np.ndarray, pd.DataFrame]:
-    return data[x] if isinstance(data, np.ndarray) else data.iloc[x]
-
-
 def read_queue(queue_: Queue) -> List[RunValue]:
     stack: List[RunValue] = []
     while True:
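
The helper itself is unchanged by the move; a short usage sketch from its new home in autoPyTorch.utils.common (import path as in the diffs above):

import numpy as np
import pandas as pd

from autoPyTorch.utils.common import subsampler

data = np.arange(12).reshape(4, 3)
frame = pd.DataFrame(data, columns=['a', 'b', 'c'])
indices = [0, 2]

# Positional row selection works uniformly for both container types:
print(subsampler(data, indices))   # ndarray rows 0 and 2 via data[indices]
print(subsampler(frame, indices))  # same rows via frame.iloc[indices]
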
@@ -11,7 +11,7 @@
     autoPyTorchTabularPreprocessingComponent
 )
 from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.utils import get_tabular_preprocessers
-from autoPyTorch.utils.common import FitRequirement
+from autoPyTorch.utils.common import FitRequirement, subsampler


 class TabularColumnTransformer(autoPyTorchTabularPreprocessingComponent):
@@ -48,7 +48,6 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer":
             "TabularColumnTransformer": an instance of self
         """
         self.check_requirements(X, y)
-
         numerical_pipeline = 'drop'
         categorical_pipeline = 'drop'

@@ -67,11 +66,11 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer":
         # Where to get the data -- Prioritize X_train if any else
         # get from backend
         if 'X_train' in X:
-            X_train = X['X_train']
+            X_train = subsampler(X['X_train'], X['train_indices'])
         else:
             X_train = X['backend'].load_datamanager().train_tensors[0]
-        self.preprocessor.fit(X_train)

+        self.preprocessor.fit(X_train)
         return self

     def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
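
A minimal sketch (using sklearn's SimpleImputer as a stand-in and made-up data, not this component's actual pipeline) of why fitting on subsampler(X['X_train'], X['train_indices']) matters: preprocessing statistics are now computed on the training fold only, so the validation fold no longer leaks into them.

import numpy as np
from sklearn.impute import SimpleImputer

from autoPyTorch.utils.common import subsampler

X_train = np.array([[1.0], [2.0], [3.0], [100.0]])  # row 3 belongs to the validation fold
train_indices = [0, 1, 2]

# Fitting on the full array leaks the validation row into the mean:
leaky = SimpleImputer(strategy='mean').fit(X_train)
print(leaky.statistics_)  # [26.5]

# Fitting on the subsampled training fold does not:
clean = SimpleImputer(strategy='mean').fit(subsampler(X_train, train_indices))
print(clean.statistics_)  # [2.]
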
@@ -78,6 +78,9 @@ def get_hyperparameter_search_space(self,
         # add only no encoder to choice hyperparameters in case the dataset is only numerical
         if len(dataset_properties['categorical_columns']) == 0:
             default = 'NoEncoder'
+            if include is not None and default not in include:
+                raise ValueError("Provided {} in include, however, the dataset "
+                                 "is incompatible with it".format(include))
             preprocessor = CSH.CategoricalHyperparameter('__choice__',
                                                          ['NoEncoder'],
                                                          default_value=default)
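
A standalone mirror (not the component's real API) of the new guard, showing when it fires: on an all-numerical dataset only NoEncoder is valid, so an include list lacking it is rejected up front.

from typing import List, Optional


def check_include(include: Optional[List[str]], default: str = 'NoEncoder') -> None:
    # Same condition and message as the guard added above.
    if include is not None and default not in include:
        raise ValueError("Provided {} in include, however, the dataset "
                         "is incompatible with it".format(include))


check_include(['NoEncoder'])          # passes silently
try:
    check_include(['OneHotEncoder'])  # nothing to encode on a numerical-only dataset
except ValueError as e:
    print(e)
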
@@ -0,0 +1,104 @@
from math import ceil, floor
from typing import Any, Dict, Optional, Tuple, Union

from ConfigSpace.conditions import EqualsCondition, InCondition
from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import (
    CategoricalHyperparameter,
    UniformFloatHyperparameter,
    UniformIntegerHyperparameter,
)

import numpy as np

import sklearn.decomposition
from sklearn.base import BaseEstimator

from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.feature_preprocessing.\
    base_feature_preprocessor import autoPyTorchFeaturePreprocessingComponent
from autoPyTorch.utils.common import FitRequirement


class KernelPCA(autoPyTorchFeaturePreprocessingComponent):
    def __init__(self, n_components: int = 10,
                 kernel: str = 'rbf', degree: int = 3,
                 gamma: float = 0.01, coef0: float = 0.0,
                 random_state: Optional[Union[int, np.random.RandomState]] = None
                 ) -> None:
        self.n_components = n_components
        self.kernel = kernel
        self.degree = degree
        self.gamma = gamma
        self.coef0 = coef0
        self.random_state = random_state
        super().__init__()

        self.add_fit_requirements([
            FitRequirement('issparse', (bool,), user_defined=True, dataset_property=True)])

    def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:

        self.preprocessor['numerical'] = sklearn.decomposition.KernelPCA(
            n_components=self.n_components, kernel=self.kernel,
            degree=self.degree, gamma=self.gamma, coef0=self.coef0,
            remove_zero_eig=True, random_state=self.random_state)

        return self

    @staticmethod
    def get_hyperparameter_search_space(
        dataset_properties: Optional[Dict[str, str]] = None,
        n_components: Tuple[Tuple, float] = ((0.5, 0.9), 0.5),
        kernel: Tuple[Tuple, str] = (('poly', 'rbf', 'sigmoid', 'cosine'), 'rbf'),
        gamma: Tuple[Tuple, float, bool] = ((3.0517578125e-05, 8), 0.01, True),
        degree: Tuple[Tuple, int] = ((2, 5), 3),
        coef0: Tuple[Tuple, float] = ((-1, 1), 0)
    ) -> ConfigurationSpace:

        if dataset_properties is not None:
            n_features = len(dataset_properties['numerical_columns'])
            n_components = ((floor(n_components[0][0] * n_features), ceil(n_components[0][1] * n_features)),
                            ceil(n_components[1] * n_features))
        else:
            n_components = ((10, 2000), 100)

        n_components = UniformIntegerHyperparameter(
            "n_components", lower=n_components[0][0], upper=n_components[0][1], default_value=n_components[1])
        kernel_hp = CategoricalHyperparameter('kernel', choices=kernel[0], default_value=kernel[1])
        gamma = UniformFloatHyperparameter(
            "gamma",
            lower=gamma[0][0], upper=gamma[0][1],
            log=gamma[2],
            default_value=gamma[1],
        )
        coef0 = UniformFloatHyperparameter("coef0", lower=coef0[0][0], upper=coef0[0][1], default_value=coef0[1])
        cs = ConfigurationSpace()
        cs.add_hyperparameters([n_components, kernel_hp, gamma, coef0])

        if "poly" in kernel_hp.choices:
            degree = UniformIntegerHyperparameter('degree', lower=degree[0][0], upper=degree[0][1],
                                                  default_value=degree[1])
            cs.add_hyperparameters([degree])
            degree_depends_on_poly = EqualsCondition(degree, kernel_hp, "poly")
            cs.add_conditions([degree_depends_on_poly])
        kernels = []
        if "sigmoid" in kernel_hp.choices:
            kernels.append("sigmoid")
        if "poly" in kernel_hp.choices:
            kernels.append("poly")
        coef0_condition = InCondition(coef0, kernel_hp, kernels)
        kernels = []
        if "rbf" in kernel_hp.choices:
            kernels.append("rbf")
        if "poly" in kernel_hp.choices:
            kernels.append("poly")
        gamma_condition = InCondition(gamma, kernel_hp, kernels)
        cs.add_conditions([coef0_condition, gamma_condition])
        return cs

    @staticmethod
    def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        return {'shortname': 'KernelPCA',
                'name': 'Kernel Principal Component Analysis',
                'handles_sparse': True
                }
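
A worked example (with an assumed n_features of 20) of the dynamic n_components bounds above: the fractional defaults ((0.5, 0.9), 0.5) are scaled by the number of numerical columns and rounded outward with floor/ceil. The gamma lower bound 3.0517578125e-05 is 2**-15.

from math import ceil, floor

n_components = ((0.5, 0.9), 0.5)  # ((lower_frac, upper_frac), default_frac)
n_features = 20                   # assumed example width

lower = floor(n_components[0][0] * n_features)  # floor(10.0) -> 10
upper = ceil(n_components[0][1] * n_features)   # ceil(18.0)  -> 18
default = ceil(n_components[1] * n_features)    # ceil(10.0)  -> 10
print(lower, upper, default)  # 10 18 10
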
@@ -0,0 +1,52 @@
from typing import Any, Dict, Optional, Union

import numpy as np

from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.feature_preprocessing.\
    base_feature_preprocessor import autoPyTorchFeaturePreprocessingComponent


class NoFeaturePreprocessor(autoPyTorchFeaturePreprocessingComponent):
    """
    Don't perform feature preprocessing on categorical features
    """
    def __init__(self,
                 random_state: Optional[Union[np.random.RandomState, int]] = None
                 ):
        super().__init__()
        self.random_state = random_state

    def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchFeaturePreprocessingComponent:
        """
        The fit function calls the fit function of the underlying model
        and returns the transformed array.
        Args:
            X (np.ndarray): input features
            y (Optional[np.ndarray]): input labels

        Returns:
            instance of self
        """
        self.check_requirements(X, y)

        return self

    def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
        """
        Adds the self into the 'X' dictionary and returns it.
        Args:
            X (Dict[str, Any]): 'X' dictionary

        Returns:
            (Dict[str, Any]): the updated 'X' dictionary
        """
        X.update({'feature_preprocessor': self.preprocessor})
        return X

    @staticmethod
    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]:
        return {
            'shortname': 'NoFeaturePreprocessing',
            'name': 'No Feature Preprocessing',
            'handles_sparse': True
        }
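
A standalone mirror of the transform contract (X here is the pipeline's fit dictionary, not raw data; the preprocessor dict layout follows KernelPCA.fit above): the component simply registers its preprocessor under the 'feature_preprocessor' key.

from typing import Any, Dict


def transform(X: Dict[str, Any], preprocessor: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors NoFeaturePreprocessor.transform(): pass-through plus registration.
    X.update({'feature_preprocessor': preprocessor})
    return X


X = {'dataset_properties': {'numerical_columns': [0, 1]}}
print(transform(X, {'numerical': None, 'categorical': None}))
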