Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature preprocessors, Loss strategies #86

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
cc583c1
ADD Weighted loss
ravinkohli Feb 5, 2021
2ea059c
Now?
ravinkohli Feb 5, 2021
14795cc
Merge branch 'feature_preprocessing' into missing_components
ravinkohli Feb 5, 2021
9f0ed18
Fix tests, flake, mypy
ravinkohli Feb 5, 2021
fb23cef
Fix tests
ravinkohli Feb 5, 2021
a8ea7b5
Fix mypy
ravinkohli Feb 8, 2021
8a618ff
change back sklearn requirement
ravinkohli Feb 8, 2021
8a389b2
Assert for fast ica sklearn bug
ravinkohli Feb 8, 2021
ce1778b
Forgot to add skip
ravinkohli Feb 8, 2021
0b1f3f0
Fix tests, changed num only data to float
ravinkohli Feb 8, 2021
0795b44
removed fast ica
ravinkohli Feb 8, 2021
bf69120
change num only dataset
ravinkohli Feb 8, 2021
e7d8606
Increased number of features in num only
ravinkohli Feb 8, 2021
8a95f61
Increase timeout for pytest
ravinkohli Feb 8, 2021
0a2d74f
ADD tensorboard to requirement
ravinkohli Feb 8, 2021
36b2c22
Fix bug with small_preprocess
ravinkohli Feb 9, 2021
d222826
Fix bug in pytest execution
ravinkohli Feb 9, 2021
90fdcfe
Fix tests
ravinkohli Feb 9, 2021
df9ec6e
ADD error is raised if default not in include
ravinkohli Feb 9, 2021
95378f7
Added dynamic search space for deciding n components in feature prepr…
ravinkohli Feb 9, 2021
0c88cab
Moved back to random configs in tabular test
ravinkohli Feb 9, 2021
3ec87b1
Added floor and ceil and handling of logs
ravinkohli Feb 9, 2021
6546d5c
Fix flake
ravinkohli Feb 9, 2021
9388c32
Remove TruncatedSVD from cs if num numerical ==1
ravinkohli Feb 9, 2021
b6c2cd0
ADD flakyness to network accuracy test
ravinkohli Feb 9, 2021
18f79e2
fix flake
ravinkohli Feb 9, 2021
e974969
remove cla to pytest
ravinkohli Feb 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class TransformSubset(Subset):

We achieve so by adding a train flag to the pytorch subset
"""

def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None:
self.dataset = dataset
self.indices = indices
Expand All @@ -59,17 +60,17 @@ def __getitem__(self, idx: int) -> np.ndarray:

class BaseDataset(Dataset, metaclass=ABCMeta):
def __init__(
self,
train_tensors: BASE_DATASET_INPUT,
dataset_name: Optional[str] = None,
val_tensors: Optional[BASE_DATASET_INPUT] = None,
test_tensors: Optional[BASE_DATASET_INPUT] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
shuffle: Optional[bool] = True,
seed: Optional[int] = 42,
train_transforms: Optional[torchvision.transforms.Compose] = None,
val_transforms: Optional[torchvision.transforms.Compose] = None,
self,
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved
train_tensors: BASE_DATASET_INPUT,
dataset_name: Optional[str] = None,
val_tensors: Optional[BASE_DATASET_INPUT] = None,
test_tensors: Optional[BASE_DATASET_INPUT] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
shuffle: Optional[bool] = True,
seed: Optional[int] = 42,
train_transforms: Optional[torchvision.transforms.Compose] = None,
val_transforms: Optional[torchvision.transforms.Compose] = None,
):
"""
Base class for datasets used in AutoPyTorch
Expand Down Expand Up @@ -245,9 +246,9 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
return splits

def create_cross_val_splits(
self,
cross_val_type: CrossValTypes,
num_splits: int
self,
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved
cross_val_type: CrossValTypes,
num_splits: int
) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
"""
This function creates the cross validation split for the given task.
Expand Down Expand Up @@ -277,9 +278,9 @@ def create_cross_val_splits(
return splits

def create_holdout_val_split(
self,
holdout_val_type: HoldoutValTypes,
val_share: float,
self,
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved
holdout_val_type: HoldoutValTypes,
val_share: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""
This function creates the holdout split for the given task.
Expand Down Expand Up @@ -371,3 +372,11 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
'num_classes': self.num_classes,
})
return dataset_properties

def get_required_dataset_info(self) -> Dict[str, Any]:
    """
    Collect the dataset properties required to instantiate a pipeline.

    Returns:
        Dict[str, Any]: a new dictionary with the keys 'output_type' and
            'issparse', read from the attributes of the same name.
    """
    return {
        'output_type': self.output_type,
        'issparse': self.issparse,
    }
14 changes: 12 additions & 2 deletions autoPyTorch/datasets/tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def __init__(self, X: Union[np.ndarray, pd.DataFrame],
# rather to have a performance through time on the test data
if X_test is not None:
X_test, self._test_data_types, _, _, _ = self.interpret_columns(X_test)

# Some quality checks on the data
if self.data_types != self._test_data_types:
raise ValueError(f"The train data inferred types {self.data_types} are "
Expand Down Expand Up @@ -225,5 +224,16 @@ def infer_dataset_properties(self, X: Any) \
numerical_columns.append(i)
categories = [np.unique(X.iloc[:, a]).tolist() for a in categorical_columns]
num_features = X.shape[1]

return categorical_columns, numerical_columns, categories, num_features

def get_required_dataset_info(self) -> Dict[str, Any]:
    """
    Collect the dataset properties required to instantiate a pipeline.

    Starts from the dictionary produced by the parent class and adds the
    tabular-specific entries on top of it.

    Returns:
        Dict[str, Any]: the parent's dictionary extended with the keys
            'numerical_columns', 'categorical_columns' and 'task_type'.
    """
    info = super().get_required_dataset_info()
    # Insert the keys one by one, in the same order as before.
    info['numerical_columns'] = self.numerical_columns
    info['categorical_columns'] = self.categorical_columns
    info['task_type'] = self.task_type
    return info
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Dict, Optional, Tuple, Union

from ConfigSpace.conditions import EqualsCondition
from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import (
CategoricalHyperparameter,
UniformIntegerHyperparameter,
)

import numpy as np

import sklearn.decomposition
from sklearn.base import BaseEstimator

from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.feature_preprocessing.\
base_feature_preprocessor import autoPyTorchFeaturePreprocessingComponent


class FastICA(autoPyTorchFeaturePreprocessingComponent):
    """
    Feature preprocessing component that configures scikit-learn's FastICA.

    The scikit-learn estimator is only constructed in `fit` (it is not fitted
    there) and stored under the 'numerical' key of `self.preprocessor`.
    """

    def __init__(self, n_components: int = 100,
                 algorithm: str = 'parallel',
                 whiten: bool = False,
                 fun: str = 'logcosh',
                 random_state: Optional[Union[int, np.random.RandomState]] = None
                 ) -> None:
        """
        Args:
            n_components (int): number of components requested from FastICA.
                Only forwarded to scikit-learn when `whiten` is enabled.
            algorithm (str): FastICA algorithm, forwarded unchanged.
            whiten (bool): whether scikit-learn should whiten the data.
            fun (str): name of the contrast function, forwarded unchanged.
            random_state (Optional[Union[int, np.random.RandomState]]):
                seed or random state forwarded to scikit-learn.
        """
        self.n_components = n_components
        self.algorithm = algorithm
        self.whiten = whiten
        self.fun = fun
        self.random_state = random_state

        super().__init__()

    def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
        """
        Build the scikit-learn FastICA estimator for the numerical columns.

        Args:
            X (Dict[str, Any]): fit dictionary (not read by this method)
            y (Any): unused

        Returns:
            instance of self
        """
        # The search space only activates `n_components` when whitening is
        # enabled. Without whitening scikit-learn ignores the value and emits
        # a warning, so hand over None instead of the constructor default.
        n_components = self.n_components if self.whiten else None

        # NOTE(review): `self.preprocessor` is presumably a dict prepared by
        # the base class - confirm.
        self.preprocessor['numerical'] = sklearn.decomposition.FastICA(
            n_components=n_components, algorithm=self.algorithm,
            fun=self.fun, whiten=self.whiten, random_state=self.random_state
        )

        return self

    @staticmethod
    def get_hyperparameter_search_space(
        dataset_properties: Optional[Dict[str, str]] = None,
        n_components: Tuple[Tuple, int] = ((10, 2000), 100),
        algorithm: Tuple[Tuple, str] = (('parallel', 'deflation'), 'parallel'),
        whiten: Tuple[Tuple, bool] = ((True, False), False),
        fun: Tuple[Tuple, str] = (('logcosh', 'exp', 'cube'), 'logcosh')
    ) -> ConfigurationSpace:
        """
        Build the configuration space of this component.

        Args:
            dataset_properties (Optional[Dict[str, str]]): not used here
            n_components (Tuple[Tuple, int]): ((lower, upper), default)
            algorithm (Tuple[Tuple, str]): (choices, default)
            whiten (Tuple[Tuple, bool]): (choices, default)
            fun (Tuple[Tuple, str]): (choices, default)

        Returns:
            ConfigurationSpace: space in which 'n_components' is only active
                when 'whiten' equals True.
        """
        # Use distinct local names so the typed arguments are not rebound to
        # hyperparameter objects.
        n_components_hp = UniformIntegerHyperparameter(
            "n_components", lower=n_components[0][0], upper=n_components[0][1], default_value=n_components[1])
        algorithm_hp = CategoricalHyperparameter('algorithm', choices=algorithm[0], default_value=algorithm[1])
        whiten_hp = CategoricalHyperparameter('whiten', choices=whiten[0], default_value=whiten[1])
        fun_hp = CategoricalHyperparameter('fun', choices=fun[0], default_value=fun[1])

        cs = ConfigurationSpace()
        cs.add_hyperparameters([n_components_hp, algorithm_hp, whiten_hp, fun_hp])

        cs.add_condition(EqualsCondition(n_components_hp, whiten_hp, True))

        return cs

    @staticmethod
    def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        Return the static description of this component.

        Args:
            dataset_properties (Optional[Dict[str, str]]): not used here

        Returns:
            Dict[str, Any]: short name, long name and sparse-support flag
        """
        return {'shortname': 'FastICA',
                'name': 'Fast Independent Component Analysis',
                'handles_sparse': True
                }
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import Any, Dict, Optional, Tuple, Union

from ConfigSpace.conditions import EqualsCondition, InCondition
from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import (
CategoricalHyperparameter,
UniformFloatHyperparameter,
UniformIntegerHyperparameter,
)

import numpy as np

import sklearn.decomposition
from sklearn.base import BaseEstimator

from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.feature_preprocessing.\
base_feature_preprocessor import autoPyTorchFeaturePreprocessingComponent
from autoPyTorch.utils.common import FitRequirement


class KernelPCA(autoPyTorchFeaturePreprocessingComponent):
    """
    Feature preprocessing component that configures scikit-learn's KernelPCA.

    The scikit-learn estimator is only constructed in `fit` (it is not fitted
    there) and stored under the 'numerical' key of `self.preprocessor`.
    """

    def __init__(self, n_components: int = 100,
                 kernel: str = 'rbf', degree: int = 3,
                 gamma: float = 0.01, coef0: float = 0.0,
                 random_state: Optional[Union[int, np.random.RandomState]] = None
                 ) -> None:
        """
        Store the KernelPCA settings and register the 'issparse' fit requirement.

        All arguments are kept as attributes of the same name and forwarded
        unchanged to scikit-learn in `fit`.
        """
        self.n_components = n_components
        self.kernel = kernel
        self.degree = degree
        self.gamma = gamma
        self.coef0 = coef0
        self.random_state = random_state
        super().__init__()

        issparse_requirement = FitRequirement('issparse', (bool,), user_defined=True, dataset_property=True)
        self.add_fit_requirements([issparse_requirement])

    def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
        """
        Build the scikit-learn KernelPCA estimator for the numerical columns.

        Args:
            X (Dict[str, Any]): fit dictionary (not read by this method)
            y (Any): unused

        Returns:
            instance of self
        """
        estimator = sklearn.decomposition.KernelPCA(
            n_components=self.n_components,
            kernel=self.kernel,
            degree=self.degree,
            gamma=self.gamma,
            coef0=self.coef0,
            remove_zero_eig=True,
            random_state=self.random_state,
        )
        # NOTE(review): a check that raised 'KernelPCA removed all features!'
        # used to be sketched here. It inspected fitted attributes, which do
        # not exist at this point because the estimator is only constructed.
        self.preprocessor['numerical'] = estimator

        return self

    @staticmethod
    def get_hyperparameter_search_space(
        dataset_properties: Optional[Dict[str, str]] = None,
        n_components: Tuple[Tuple, int] = ((10, 2000), 100),
        kernel: Tuple[Tuple, str] = (('poly', 'rbf', 'sigmoid', 'cosine'), 'rbf'),
        gamma: Tuple[Tuple, float, bool] = ((3.0517578125e-05, 8), 0.01, True),
        degree: Tuple[Tuple, int] = ((2, 5), 3),
        coef0: Tuple[Tuple, float] = ((-1, 1), 0)
    ) -> ConfigurationSpace:
        """
        Build the configuration space of this component.

        Args:
            dataset_properties (Optional[Dict[str, str]]): not used here
            n_components (Tuple[Tuple, int]): ((lower, upper), default)
            kernel (Tuple[Tuple, str]): (choices, default)
            gamma (Tuple[Tuple, float, bool]): ((lower, upper), default, log scale)
            degree (Tuple[Tuple, int]): ((lower, upper), default)
            coef0 (Tuple[Tuple, float]): ((lower, upper), default)

        Returns:
            ConfigurationSpace: space in which 'degree' is tied to the 'poly'
                kernel, 'coef0' to 'sigmoid'/'poly' and 'gamma' to 'rbf'/'poly'.
        """
        components_bounds = n_components[0]
        components_hp = UniformIntegerHyperparameter(
            "n_components",
            lower=components_bounds[0],
            upper=components_bounds[1],
            default_value=n_components[1],
        )
        kernel_hp = CategoricalHyperparameter('kernel', choices=kernel[0], default_value=kernel[1])
        gamma_bounds = gamma[0]
        gamma_hp = UniformFloatHyperparameter(
            "gamma",
            lower=gamma_bounds[0],
            upper=gamma_bounds[1],
            log=gamma[2],
            default_value=gamma[1],
        )
        coef0_bounds = coef0[0]
        coef0_hp = UniformFloatHyperparameter(
            "coef0",
            lower=coef0_bounds[0],
            upper=coef0_bounds[1],
            default_value=coef0[1],
        )

        cs = ConfigurationSpace()
        cs.add_hyperparameters([components_hp, kernel_hp, gamma_hp, coef0_hp])

        # 'degree' only exists in the space when the polynomial kernel is selectable.
        if "poly" in kernel_hp.choices:
            degree_bounds = degree[0]
            degree_hp = UniformIntegerHyperparameter(
                'degree',
                lower=degree_bounds[0],
                upper=degree_bounds[1],
                default_value=degree[1],
            )
            cs.add_hyperparameters([degree_hp])
            cs.add_conditions([EqualsCondition(degree_hp, kernel_hp, "poly")])

        # Restrict each condition to the kernels that are actually selectable,
        # keeping the same ordering of the kernel names as before.
        coef0_kernels = [name for name in ("sigmoid", "poly") if name in kernel_hp.choices]
        coef0_condition = InCondition(coef0_hp, kernel_hp, coef0_kernels)
        gamma_kernels = [name for name in ("rbf", "poly") if name in kernel_hp.choices]
        gamma_condition = InCondition(gamma_hp, kernel_hp, gamma_kernels)
        cs.add_conditions([coef0_condition, gamma_condition])
        return cs

    @staticmethod
    def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        Return the static description of this component.

        Args:
            dataset_properties (Optional[Dict[str, str]]): not used here

        Returns:
            Dict[str, Any]: short name, long name and sparse-support flag
        """
        properties = {
            'shortname': 'KernelPCA',
            'name': 'Kernel Principal Component Analysis',
            'handles_sparse': True,
        }
        return properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Any, Dict, Optional, Union

import numpy as np

from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.feature_preprocessing.\
base_feature_preprocessor import autoPyTorchFeaturePreprocessingComponent


class NoFeaturePreprocessor(autoPyTorchFeaturePreprocessingComponent):
    """
    Feature preprocessing choice that configures no preprocessor of its own.

    `fit` only validates the fit requirements and `transform` forwards
    `self.preprocessor` unchanged.
    """

    def __init__(self,
                 random_state: Optional[Union[np.random.RandomState, int]] = None
                 ):
        """
        Args:
            random_state (Optional[Union[np.random.RandomState, int]]):
                stored on the instance; not used by the methods of this class.
        """
        super().__init__()
        self.random_state = random_state

    def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchFeaturePreprocessingComponent:
        """
        Validate the fit requirements; nothing is learned from the data.

        Args:
            X (Dict[str, Any]): fit dictionary handed to `check_requirements`
            y (Any): handed to `check_requirements` alongside 'X'

        Returns:
            instance of self
        """
        self.check_requirements(X, y)

        return self

    def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
        """
        Store this component's `preprocessor` attribute in the 'X' dictionary.

        Args:
            X (Dict[str, Any]): 'X' dictionary, modified in place

        Returns:
            (Dict[str, Any]): the same 'X' dictionary, now holding the
                'feature_preprocessor' key
        """
        X['feature_preprocessor'] = self.preprocessor
        return X

    @staticmethod
    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]:
        """
        Return the static description of this component.

        Args:
            dataset_properties (Optional[Dict[str, Any]]): not used here

        Returns:
            Dict[str, Union[str, bool]]: short name, long name and
                sparse-support flag
        """
        properties: Dict[str, Union[str, bool]] = {
            'shortname': 'NoFeaturePreprocessing',
            'name': 'No Feature Preprocessing',
            'handles_sparse': True,
        }
        return properties
Loading