diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index a048e2054..89ec14b8c 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -39,6 +39,7 @@
     STRING_TO_TASK_TYPES,
 )
 from autoPyTorch.data.base_validator import BaseInputValidator
+from autoPyTorch.data.utils import DatasetCompressionSpec
 from autoPyTorch.datasets.base_dataset import BaseDataset, BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
     CrossValTypes,
@@ -299,6 +300,7 @@ def _get_dataset_input_validator(
         resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
+        dataset_compression: Optional[DatasetCompressionSpec] = None,
     ) -> Tuple[BaseDataset, BaseInputValidator]:
         """
         Returns an object of a child class of `BaseDataset` and
@@ -323,6 +325,9 @@
                 in ```datasets/resampling_strategy.py```.
             dataset_name (Optional[str]):
                 name of the dataset, used as experiment name.
+            dataset_compression (Optional[DatasetCompressionSpec]):
+                specifications for dataset compression. For more info check
+                documentation for `BaseTask.get_dataset`.
 
         Returns:
             BaseDataset:
@@ -341,6 +346,7 @@ def get_dataset(
         resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
+        dataset_compression: Optional[DatasetCompressionSpec] = None,
     ) -> BaseDataset:
         """
         Returns an object of a child class of `BaseDataset` according to the current task.
@@ -363,6 +369,38 @@
                 in ```datasets/resampling_strategy.py```.
             dataset_name (Optional[str]):
                 name of the dataset, used as experiment name.
+            dataset_compression (Optional[DatasetCompressionSpec]):
+                We compress datasets so that they fit into some predefined amount of memory.
+
+                **NOTE:** You can also pass your own configuration with the same keys,
+                choosing from the available ``"methods"``.
+                The available options are described here:
+
+                **memory_allocation**
+                    Absolute memory in MB, e.g. 10MB is ``"memory_allocation": 10``.
+                    The memory used by the dataset is checked after each reduction method is
+                    performed. If the dataset fits into the allocated memory, any further methods
+                    listed in ``"methods"`` will not be performed.
+                    The value can be either a float or an int.
+
+                **methods**
+                    We currently provide the following methods for reducing the dataset size.
+                    These can be provided in a list and are performed in the order as given.
+                    * ``"precision"`` -
+                        We reduce floating point precision as follows:
+                        * ``np.float128 -> np.float64``
+                        * ``np.float96 -> np.float64``
+                        * ``np.float64 -> np.float32``
+                        * pandas dataframes are reduced using the downcast option of
+                          `pd.to_numeric` to the lowest possible precision.
+                    * ``"subsample"`` -
+                        We subsample data such that it **fits directly into the memory
+                        allocation**. Therefore, this should likely be the last method
+                        listed in ``"methods"``.
+                        Subsampling takes into account classification labels and stratifies
+                        accordingly. We guarantee that at least one occurrence of each
+                        label is included in the sampled set.
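For reference, a minimal sketch of passing such a spec to ``get_dataset`` (the data variables are assumed to be already loaded, and the values are illustrative, not defaults):

```python
from autoPyTorch.api.tabular_classification import TabularClassificationTask

# Illustrative spec: try precision reduction first, then subsample,
# within an absolute budget of 100MB for the dataset.
compression_spec = {"memory_allocation": 100, "methods": ["precision", "subsample"]}

api = TabularClassificationTask()
dataset = api.get_dataset(
    X_train=X_train, y_train=y_train,   # assumed pre-loaded features/targets
    X_test=X_test, y_test=y_test,
    dataset_compression=compression_spec,
)
```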
 
         Returns:
             BaseDataset:
@@ -375,7 +413,8 @@ def get_dataset(
             y_test=y_test,
             resampling_strategy=resampling_strategy,
             resampling_strategy_args=resampling_strategy_args,
-            dataset_name=dataset_name)
+            dataset_name=dataset_name,
+            dataset_compression=dataset_compression)
 
         return dataset
 
diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py
index 684c22a7b..3d80a0338 100644
--- a/autoPyTorch/api/tabular_classification.py
+++ b/autoPyTorch/api/tabular_classification.py
@@ -12,7 +12,8 @@
 )
 from autoPyTorch.data.tabular_validator import TabularInputValidator
 from autoPyTorch.data.utils import (
-    get_dataset_compression_mapping
+    DatasetCompressionSpec,
+    get_dataset_compression_mapping,
 )
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
@@ -166,7 +167,7 @@ def _get_dataset_input_validator(
         resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
-        dataset_compression: Optional[Mapping[str, Any]] = None,
+        dataset_compression: Optional[DatasetCompressionSpec] = None,
     ) -> Tuple[TabularDataset, TabularInputValidator]:
         """
         Returns an object of `TabularDataset` and an object of
@@ -190,6 +191,10 @@
                 in ```datasets/resampling_strategy.py```.
             dataset_name (Optional[str]):
                 name of the dataset, used as experiment name.
+            dataset_compression (Optional[DatasetCompressionSpec]):
+                specifications for dataset compression. For more info check
+                documentation for `BaseTask.get_dataset`.
+
         Returns:
             TabularDataset:
                 the dataset object.
@@ -396,14 +401,23 @@ def search(
                     listed in ``"methods"`` will not be performed.
 
                 **methods**
-                    We currently provide the following methods for reducing the dataset size.
-                    These can be provided in a list and are performed in the order as given.
-                    * ``"precision"`` - We reduce floating point precision as follows:
-                        * ``np.float128 -> np.float64``
-                        * ``np.float96 -> np.float64``
-                        * ``np.float64 -> np.float32``
-                        * pandas dataframes are reduced using the downcast option of `pd.to_numeric`
-                          to the lowest possible precision.
+                    We currently provide the following methods for reducing the dataset size.
+                    These can be provided in a list and are performed in the order as given.
+                    * ``"precision"`` -
+                        We reduce floating point precision as follows:
+                        * ``np.float128 -> np.float64``
+                        * ``np.float96 -> np.float64``
+                        * ``np.float64 -> np.float32``
+                        * pandas dataframes are reduced using the downcast option of
+                          `pd.to_numeric` to the lowest possible precision.
+                    * ``"subsample"`` -
+                        We subsample data such that it **fits directly into the memory
+                        allocation** ``memory_allocation * memory_limit``. Therefore, this
+                        should likely be the last method listed in ``"methods"``.
+                        Subsampling takes into account classification labels and stratifies
+                        accordingly. We guarantee that at least one occurrence of each
+                        label is included in the sampled set.
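The same spec can also be handed to ``search``; a hedged sketch, with all data handling elided and argument values chosen only for illustration (here ``memory_allocation`` is the float form, a fraction of ``memory_limit`` as described above):

```python
api = TabularClassificationTask()
api.search(
    X_train=X_train, y_train=y_train,
    optimize_metric="accuracy",
    memory_limit=4096,  # MB granted to the job
    dataset_compression={
        "memory_allocation": 0.2,                # 20% of memory_limit for the data
        "methods": ["precision", "subsample"],   # applied in this order
    },
)
```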
 
         Returns:
             self
 
diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py
index d766bad68..fa8cf8081 100644
--- a/autoPyTorch/api/tabular_regression.py
+++ b/autoPyTorch/api/tabular_regression.py
@@ -12,7 +12,8 @@
 )
 from autoPyTorch.data.tabular_validator import TabularInputValidator
 from autoPyTorch.data.utils import (
-    get_dataset_compression_mapping
+    DatasetCompressionSpec,
+    get_dataset_compression_mapping,
 )
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
@@ -167,7 +168,7 @@ def _get_dataset_input_validator(
         resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
-        dataset_compression: Optional[Mapping[str, Any]] = None,
+        dataset_compression: Optional[DatasetCompressionSpec] = None,
     ) -> Tuple[TabularDataset, TabularInputValidator]:
         """
         Returns an object of `TabularDataset` and an object of
@@ -191,6 +192,9 @@
                 in ```datasets/resampling_strategy.py```.
             dataset_name (Optional[str]):
                 name of the dataset, used as experiment name.
+            dataset_compression (Optional[DatasetCompressionSpec]):
+                specifications for dataset compression. For more info check
+                documentation for `BaseTask.get_dataset`.
 
         Returns:
             TabularDataset:
                 the dataset object.
@@ -397,14 +401,23 @@ def search(
                     listed in ``"methods"`` will not be performed.
 
                 **methods**
-                    We currently provide the following methods for reducing the dataset size.
-                    These can be provided in a list and are performed in the order as given.
-                    * ``"precision"`` - We reduce floating point precision as follows:
-                        * ``np.float128 -> np.float64``
-                        * ``np.float96 -> np.float64``
-                        * ``np.float64 -> np.float32``
-                        * pandas dataframes are reduced using the downcast option of `pd.to_numeric`
-                          to the lowest possible precision.
+                    We currently provide the following methods for reducing the dataset size.
+                    These can be provided in a list and are performed in the order as given.
+                    * ``"precision"`` -
+                        We reduce floating point precision as follows:
+                        * ``np.float128 -> np.float64``
+                        * ``np.float96 -> np.float64``
+                        * ``np.float64 -> np.float32``
+                        * pandas dataframes are reduced using the downcast option of
+                          `pd.to_numeric` to the lowest possible precision.
+                    * ``"subsample"`` -
+                        We subsample data such that it **fits directly into the memory
+                        allocation** ``memory_allocation * memory_limit``. Therefore, this
+                        should likely be the last method listed in ``"methods"``.
+                        Subsampling takes into account classification labels and stratifies
+                        accordingly. We guarantee that at least one occurrence of each
+                        label is included in the sampled set.
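To make the two forms of the budget concrete, a small worked sketch of the float-to-absolute conversion performed by ``validate_dataset_compression_arg`` (values are illustrative):

```python
memory_limit = 4096        # MB, as configured for the run
spec = {"memory_allocation": 0.1, "methods": ["precision", "subsample"]}

# A float spec is converted to an absolute budget in MB:
budget_mb = spec["memory_allocation"] * memory_limit   # 0.1 * 4096 = 409.6 MB

# Each method then runs only while megabytes(X) > budget_mb;
# once the dataset fits, the remaining methods are skipped.
```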
 
         Returns:
             self
 
diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py
index 3e8c316b0..8dad37205 100644
--- a/autoPyTorch/data/tabular_feature_validator.py
+++ b/autoPyTorch/data/tabular_feature_validator.py
@@ -1,6 +1,6 @@
 import functools
 from logging import Logger
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import Dict, List, Optional, Tuple, Union, cast
 
 import numpy as np
 
@@ -18,11 +18,6 @@
 from sklearn.pipeline import make_pipeline
 
 from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SupportedFeatTypes
-from autoPyTorch.data.utils import (
-    DatasetCompressionInputType,
-    DatasetDTypeContainerType,
-    reduce_dataset_size_if_too_large
-)
 from autoPyTorch.utils.common import ispandas
 from autoPyTorch.utils.logging_ import PicklableClientLogger
 
@@ -103,10 +98,7 @@ class TabularFeatureValidator(BaseFeatureValidator):
     def __init__(
         self,
         logger: Optional[Union[PicklableClientLogger, Logger]] = None,
-        dataset_compression: Optional[Mapping[str, Any]] = None,
-    ) -> None:
-        self._dataset_compression = dataset_compression
-        self._reduced_dtype: Optional[DatasetDTypeContainerType] = None
+    ):
         super().__init__(logger)
 
     @staticmethod
@@ -290,38 +282,8 @@ def transform(
                     "numerical or categorical values.")
                 raise e
 
-        X = self._compress_dataset(X)
-
         return X
 
-    # TODO: modify once we have added subsampling as well.
-    def _compress_dataset(self, X: DatasetCompressionInputType) -> DatasetCompressionInputType:
-        """
-        Compress the dataset. This function ensures that
-        the testing data is converted to the same dtype as
-        the training data.
-
-        Args:
-            X (DatasetCompressionInputType):
-                Dataset
-
-        Returns:
-            DatasetCompressionInputType:
-                Compressed dataset.
- """ - is_dataframe = ispandas(X) - is_reducible_type = isinstance(X, np.ndarray) or issparse(X) or is_dataframe - if not is_reducible_type or self._dataset_compression is None: - return X - elif self._reduced_dtype is not None: - X = X.astype(self._reduced_dtype) - return X - else: - X = reduce_dataset_size_if_too_large(X, **self._dataset_compression) - self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype - return X - def _check_data( self, X: SupportedFeatTypes, diff --git a/autoPyTorch/data/tabular_validator.py b/autoPyTorch/data/tabular_validator.py index 4db415f93..492327fbe 100644 --- a/autoPyTorch/data/tabular_validator.py +++ b/autoPyTorch/data/tabular_validator.py @@ -1,10 +1,21 @@ # -*- encoding: utf-8 -*- import logging -from typing import Any, Mapping, Optional, Union +from typing import Optional, Tuple, Union + +import numpy as np + +from scipy.sparse import issparse from autoPyTorch.data.base_validator import BaseInputValidator -from autoPyTorch.data.tabular_feature_validator import TabularFeatureValidator -from autoPyTorch.data.tabular_target_validator import TabularTargetValidator +from autoPyTorch.data.tabular_feature_validator import SupportedFeatTypes, TabularFeatureValidator +from autoPyTorch.data.tabular_target_validator import SupportedTargetTypes, TabularTargetValidator +from autoPyTorch.data.utils import ( + DatasetCompressionInputType, + DatasetCompressionSpec, + DatasetDTypeContainerType, + reduce_dataset_size_if_too_large +) +from autoPyTorch.utils.common import ispandas from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger @@ -27,16 +38,22 @@ class TabularInputValidator(BaseInputValidator): target_validator (TargetValidator): A TargetValidator instance used to validate and encode (in case of classification) the target values + dataset_compression (Optional[DatasetCompressionSpec]): + specifications for dataset compression. For more info check + documentation for `BaseTask.get_dataset`. """ def __init__( self, is_classification: bool = False, logger_port: Optional[int] = None, - dataset_compression: Optional[Mapping[str, Any]] = None, - ) -> None: + dataset_compression: Optional[DatasetCompressionSpec] = None, + seed: int = 42, + ): + self.dataset_compression = dataset_compression + self._reduced_dtype: Optional[DatasetDTypeContainerType] = None self.is_classification = is_classification self.logger_port = logger_port - self.dataset_compression = dataset_compression + self.seed = seed if self.logger_port is not None: self.logger: Union[logging.Logger, PicklableClientLogger] = get_named_client_logger( name='Validation', @@ -46,10 +63,59 @@ def __init__( self.logger = logging.getLogger('Validation') self.feature_validator = TabularFeatureValidator( - dataset_compression=self.dataset_compression, logger=self.logger) self.target_validator = TabularTargetValidator( is_classification=self.is_classification, logger=self.logger ) self._is_fitted = False + + def _compress_dataset( + self, + X: DatasetCompressionInputType, + y: SupportedTargetTypes, + ) -> DatasetCompressionInputType: + """ + Compress the dataset. This function ensures that + the testing data is converted to the same dtype as + the training data. + See `autoPyTorch.data.utils.reduce_dataset_size_if_too_large` + for more information. + + Args: + X (DatasetCompressionInputType): + features of dataset + y (SupportedTargetTypes): + targets of dataset + Returns: + DatasetCompressionInputType: + Compressed dataset. 
+ """ + is_dataframe = ispandas(X) + is_reducible_type = isinstance(X, np.ndarray) or issparse(X) or is_dataframe + if not is_reducible_type or self.dataset_compression is None: + return X, y + elif self._reduced_dtype is not None: + X = X.astype(self._reduced_dtype) + return X, y + else: + X, y = reduce_dataset_size_if_too_large( + X, + y=y, + is_classification=self.is_classification, + random_state=self.seed, + **self.dataset_compression # type: ignore [arg-type] + ) + self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype + return X, y + + def transform( + self, + X: SupportedFeatTypes, + y: Optional[SupportedTargetTypes] = None, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + + X, y = super().transform(X, y) + X_reduced, y_reduced = self._compress_dataset(X, y) + + return X_reduced, y_reduced diff --git a/autoPyTorch/data/utils.py b/autoPyTorch/data/utils.py index 03375ce27..20ad5612e 100644 --- a/autoPyTorch/data/utils.py +++ b/autoPyTorch/data/utils.py @@ -1,6 +1,5 @@ # Implementation used from https://github.com/automl/auto-sklearn/blob/development/autosklearn/util/data.py import warnings -from math import floor from typing import ( Any, Dict, @@ -21,8 +20,13 @@ from scipy.sparse import issparse, spmatrix -from autoPyTorch.utils.common import ispandas +from sklearn.model_selection import StratifiedShuffleSplit, train_test_split +from sklearn.model_selection._split import _validate_shuffle_split +from sklearn.utils import _approximate_mode, check_random_state +from sklearn.utils.validation import _num_samples, check_array +from autoPyTorch.data.base_target_validator import SupportedTargetTypes +from autoPyTorch.utils.common import ispandas # TODO: TypedDict with python 3.8 # @@ -35,10 +39,108 @@ # Default specification for arg `dataset_compression` default_dataset_compression_arg: DatasetCompressionSpec = { "memory_allocation": 0.1, - "methods": ["precision"] + "methods": ["precision", "subsample"] } +class CustomStratifiedShuffleSplit(StratifiedShuffleSplit): + """Splitter that deals with classes with too few samples""" + + def _iter_indices(self, X, y, groups=None): # type: ignore + n_samples = _num_samples(X) + y = check_array(y, ensure_2d=False, dtype=None) + n_train, n_test = _validate_shuffle_split( + n_samples, + self.test_size, + self.train_size, + default_test_size=self._default_test_size, + ) + + if y.ndim == 2: + # for multi-label y, map each distinct row to a string repr + # using join because str(row) uses an ellipsis if len(row) > 1000 + y = np.array([" ".join(row.astype("str")) for row in y]) + + classes, y_indices = np.unique(y, return_inverse=True) + n_classes = classes.shape[0] + + class_counts = np.bincount(y_indices) + + if n_train < n_classes: + raise ValueError( + "The train_size = %d should be greater or " + "equal to the number of classes = %d" % (n_train, n_classes) + ) + if n_test < n_classes: + raise ValueError( + "The test_size = %d should be greater or " + "equal to the number of classes = %d" % (n_test, n_classes) + ) + + # Find the sorted list of instances for each class: + # (np.unique above performs a sort, so code is O(n logn) already) + class_indices = np.split( + np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1] + ) + + rng = check_random_state(self.random_state) + + for _ in range(self.n_splits): + # if there are ties in the class-counts, we want + # to make sure to break them anew in each iteration + n_i = _approximate_mode(class_counts, n_train, rng) + class_counts_remaining = class_counts - n_i + t_i = 
+            t_i = _approximate_mode(class_counts_remaining, n_test, rng)
+            train = []
+            test = []
+
+            # NOTE: Adapting for unique instances
+            #
+            # Each array n_i, t_i gives the per-class sample counts for the
+            # training set and test set respectively.
+            #
+            # n_i = [100, 100, 0, 3]  # 100 of class '0', 0 of class '2'
+            # t_i = [300, 300, 1, 3]  # 300 of class '0', 1 of class '2'
+            #
+            # To support unique labels such as class '2', which only has one sample
+            # between both n_i and t_i, we need to make sure that n_i has at least
+            # one sample of all classes. There is also the extra check to ensure
+            # that the sizes stay the same.
+            #
+            # n_i = [ 99, 100, 1, 3]  # 99 of class '0', 1 of class '2'
+            #        |            ^
+            #        v            |
+            # t_i = [301, 300, 0, 3]  # 301 of class '0', 0 of class '2'
+            #
+            for i, class_count in enumerate(n_i):
+                if class_count == 0:
+                    t_i[i] -= 1
+                    n_i[i] += 1
+
+                    j = np.argmax(n_i)
+                    if n_i[j] == 1:
+                        warnings.warn(
+                            "Can't respect size requirements for split."
+                            " The training set must contain all of the unique"
+                            " labels that exist in the dataset."
+                        )
+                    else:
+                        n_i[j] -= 1
+                        t_i[j] += 1
+
+            for i in range(n_classes):
+                permutation = rng.permutation(class_counts[i])
+                perm_indices_class_i = class_indices[i].take(permutation, mode="clip")
+
+                train.extend(perm_indices_class_i[: n_i[i]])
+                test.extend(perm_indices_class_i[n_i[i]: n_i[i] + t_i[i]])
+
+            train = rng.permutation(train)
+            test = rng.permutation(test)
+
+            yield train, test
+
+
 def get_dataset_compression_mapping(
     memory_limit: int,
     dataset_compression: Union[bool, Mapping[str, Any]]
@@ -137,8 +239,8 @@ def validate_dataset_compression_arg(
             f"\nmemory_allocation = {memory_allocation}"
             f"\ndataset_compression = {dataset_compression}"
         )
-    # convert to int so we can directly use
-    dataset_compression["memory_allocation"] = floor(memory_allocation * memory_limit)
+    # convert to required memory so we can directly use
+    dataset_compression["memory_allocation"] = memory_allocation * memory_limit
 
     # "methods" must be non-empty sequence
     if (
@@ -266,6 +368,97 @@ def reduce_precision(
     return X, reduced_dtypes, dtypes
 
 
+def subsample(
+    X: DatasetCompressionInputType,
+    is_classification: bool,
+    sample_size: Union[float, int],
+    y: Optional[SupportedTargetTypes] = None,
+    random_state: Optional[Union[int, np.random.RandomState]] = None,
+) -> Tuple[DatasetCompressionInputType, SupportedTargetTypes]:
+    """Subsamples data, returning the same type as it received.
+
+    If `is_classification`, we split using a stratified shuffle split which
+    preserves unique labels in the training set.
+
+    NOTE:
+        It's highly inadvisable to use lists here. In order to preserve types,
+        we convert to a numpy array and then back to a list.
+
+    NOTE 2:
+        Interestingly enough, StratifiedShuffleSplit and descendants don't support
+        sparse `y` in the `check_array` call inside `split()`. Hence, neither do we.
+
+    Args:
+        X: DatasetCompressionInputType
+            The X's to subsample
+        y: SupportedTargetTypes
+            The y's to subsample
+        is_classification: bool
+            Whether this is classification data or regression data. Required for
+            knowing how to split.
+        sample_size: float | int
+            If float, the fraction of data to keep; if int, an absolute
+            count of samples to take.
+        random_state: int | RandomState = None
+            The random state to pass to the splitter
+
+    Returns:
+        (DatasetCompressionInputType, SupportedTargetTypes)
+            The X and y subsampled according to sample_size
+    """
+
+    if isinstance(X, List):
+        X = np.asarray(X)
+    if isinstance(y, List):
+        y = np.asarray(y)
+
+    if is_classification and y is not None:
+        splitter = CustomStratifiedShuffleSplit(
+            train_size=sample_size, random_state=random_state
+        )
+        indices_to_keep, _ = next(splitter.split(X=X, y=y))
+        X, y = _subsample_by_indices(X, y, indices_to_keep)
+
+    elif y is None:
+        X, _ = train_test_split(  # type: ignore
+            X,
+            train_size=sample_size,
+            random_state=random_state,
+        )
+    else:
+        X, _, y, _ = train_test_split(  # type: ignore
+            X,
+            y,
+            train_size=sample_size,
+            random_state=random_state,
+        )
+
+    return X, y
+
+
+def _subsample_by_indices(
+    X: DatasetCompressionInputType,
+    y: SupportedTargetTypes,
+    indices_to_keep: np.ndarray
+) -> Tuple[DatasetCompressionInputType, SupportedTargetTypes]:
+    """
+    subsample data by given indices
+    """
+    if ispandas(X):
+        idxs = X.index[indices_to_keep]
+        X = X.loc[idxs]
+    else:
+        X = X[indices_to_keep]
+
+    if ispandas(y):
+        # Ignoring types as mypy does not infer y as dataframe.
+        idxs = y.index[indices_to_keep]  # type: ignore [index]
+        y = y.loc[idxs]  # type: ignore [union-attr]
+    else:
+        y = y[indices_to_keep]
+    return X, y
+
+
 def megabytes(arr: DatasetCompressionInputType) -> float:
 
     if isinstance(arr, np.ndarray):
@@ -283,8 +476,11 @@
 
 def reduce_dataset_size_if_too_large(
     X: DatasetCompressionInputType,
-    memory_allocation: int,
-    methods: List[str] = ['precision'],
-) -> DatasetCompressionInputType:
+    memory_allocation: Union[int, float],
+    is_classification: bool,
+    random_state: Union[int, np.random.RandomState],
+    y: Optional[SupportedTargetTypes] = None,
+    methods: List[str] = ['precision', 'subsample'],
+) -> Tuple[DatasetCompressionInputType, Optional[SupportedTargetTypes]]:
    f""" Reduces the size of the dataset if it's too close to the memory limit.
@@ -305,15 +501,20 @@
        X: DatasetCompressionInputType
            The features of the dataset.
 
-        methods: List[str] = ['precision']
+        methods (List[str] = ['precision', 'subsample']):
            A list of operations that are permitted to be performed to reduce
            the size of the dataset.
 
            **precision**
-                Reduce the precision of float types
+                Reduce the precision of float types
+
+            **subsample**
+                Reduce the amount of samples of the dataset such that it fits into the allocated
+                memory. Ensures stratification and that unique labels are present
 
-        memory_allocation: int
+
+        memory_allocation (Union[int, float]):
            The amount of memory to allocate to the dataset. It should specify an
            absolute amount.
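As a quick illustration of the stratified path above, a self-contained sketch of calling ``subsample`` directly on synthetic data (the exact rows kept depend on the splitter's RNG):

```python
import numpy as np
from autoPyTorch.data.utils import subsample

# 100 rows, 3 classes, one of which ('2') appears exactly once.
X = np.random.rand(100, 4)
y = np.array([0] * 60 + [1] * 39 + [2])

X_sub, y_sub = subsample(
    X,
    y=y,
    is_classification=True,
    sample_size=0.25,   # keep roughly 25% of the rows
    random_state=0,
)
assert 2 in y_sub  # the unique label survives, per the guarantee documented above
```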
@@ -323,17 +524,44 @@
     """
     for method in methods:
+        if megabytes(X) <= memory_allocation:
+            break
         if method == 'precision':
             # If the dataset is too big for the allocated memory,
             # we then try to reduce the precision if it's a high precision dataset
-            if megabytes(X) > memory_allocation:
-                X, reduced_dtypes, dtypes = reduce_precision(X)
-                warnings.warn(
-                    f'Dataset too large for allocated memory {memory_allocation}MB, '
-                    f'reduced the precision from {dtypes} to {reduced_dtypes}',
-                )
+            X, reduced_dtypes, dtypes = reduce_precision(X)
+            warnings.warn(
+                f'Dataset too large for allocated memory {memory_allocation}MB, '
+                f'reduced the precision from {dtypes} to {reduced_dtypes}',
+            )
+        elif method == "subsample":
+            # If the dataset is still too big to fit into the allocated
+            # memory, we subsample it so that it does fit.
+            n_samples_before = X.shape[0]
+            sample_percentage = memory_allocation / megabytes(X)
+
+            # NOTE: type ignore
+            #
+            # Tried the generic `def subsample(X: T) -> T` approach but it was
+            # failing elsewhere, keeping it simple for now
+            X, y = subsample(  # type: ignore
+                X,
+                y=y,
+                sample_size=sample_percentage,
+                is_classification=is_classification,
+                random_state=random_state,
+            )
+
+            n_samples_after = X.shape[0]
+            warnings.warn(
+                f"Dataset too large for allocated memory {memory_allocation}MB,"
+                f" reduced number of samples from {n_samples_before} to"
+                f" {n_samples_after}."
+            )
+
         else:
             raise ValueError(f"Unknown operation `{method}`")
 
-    return X
+    return X, y
diff --git a/autoPyTorch/pipeline/components/training/trainer/__init__.py b/autoPyTorch/pipeline/components/training/trainer/__init__.py
index c1008b3ba..1c3a74068 100755
--- a/autoPyTorch/pipeline/components/training/trainer/__init__.py
+++ b/autoPyTorch/pipeline/components/training/trainer/__init__.py
@@ -18,7 +18,7 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.tensorboard.writer import SummaryWriter
 
-from autoPyTorch.constants import STRING_TO_TASK_TYPES
+from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
 from autoPyTorch.pipeline.components.base_component import (
@@ -257,6 +257,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
         if 'optimize_metric' in X and X['optimize_metric'] not in [m.name for m in metrics]:
             metrics.extend(get_metrics(dataset_properties=X['dataset_properties'],
                                        names=[X['optimize_metric']]))
         additional_losses = X['additional_losses'] if 'additional_losses' in X else None
+
+        labels = self._get_train_label(X)
+
         self.choice.prepare(
             model=X['network'],
             metrics=metrics,
@@ -268,7 +271,7 @@
             metrics_during_training=X['metrics_during_training'],
             scheduler=X['lr_scheduler'],
             task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
-            labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]],
+            labels=labels,
             step_interval=X['step_interval']
         )
         total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
@@ -381,6 +384,21 @@
 
         return self
 
+    def _get_train_label(self, X: Dict[str, Any]) -> List[int]:
+        """
+        Verifies and validates the labels from the train split.
+ """ + # Ensure that the split is not missing any class. + labels: List[int] = X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]] + if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS: + unique_labels = len(np.unique(labels)) + if unique_labels < X['dataset_properties']['output_shape']: + raise ValueError(f"Expected number of unique labels {unique_labels} in train split: {X['split_id']}" + f" to be = num_classes {X['dataset_properties']['output_shape']}." + f" Consider using stratified splitting strategies.") + + return labels + def _load_best_weights_and_clean_checkpoints(self, X: Dict[str, Any]) -> None: """ Load the best model until the last epoch and delete all the files for checkpoints. diff --git a/test/conftest.py b/test/conftest.py index 604d8f00e..4496594ab 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -15,6 +15,8 @@ import pytest +from scipy import sparse + from sklearn.datasets import fetch_openml, make_classification, make_regression import torch @@ -466,3 +468,150 @@ def loss_details(request): @pytest.fixture def n_samples(): return N_SAMPLES + + +# Fixtures for input validators. By default all elements have 100 datapoints +@pytest.fixture +def input_data_featuretest(request): + if request.param == 'numpy_categoricalonly_nonan': + return np.random.randint(10, size=(100, 10)) + elif request.param == 'numpy_numericalonly_nonan': + return np.random.uniform(10, size=(100, 10)) + elif request.param == 'numpy_mixed_nonan': + return np.column_stack([ + np.random.uniform(10, size=(100, 3)), + np.random.randint(10, size=(100, 3)), + np.random.uniform(10, size=(100, 3)), + np.random.randint(10, size=(100, 1)), + ]) + elif request.param == 'numpy_string_nonan': + return np.array([ + ['a', 'b', 'c', 'a', 'b', 'c'], + ['a', 'b', 'd', 'r', 'b', 'c'], + ]) + elif request.param == 'numpy_categoricalonly_nan': + array = np.random.randint(10, size=(100, 10)).astype('float') + array[50, 0:5] = np.nan + return array + elif request.param == 'numpy_numericalonly_nan': + array = np.full(fill_value=10.0, shape=(100, 10), dtype=np.float64) + array[50, 0:5] = np.nan + # Somehow array is changed to dtype object after np.nan + return array.astype('float') + elif request.param == 'numpy_mixed_nan': + array = np.column_stack([ + np.random.uniform(10, size=(100, 3)), + np.random.randint(10, size=(100, 3)), + np.random.uniform(10, size=(100, 3)), + np.random.randint(10, size=(100, 1)), + ]) + array[50, 0:5] = np.nan + return array + elif request.param == 'numpy_string_nan': + return np.array([ + ['a', 'b', 'c', 'a', 'b', 'c'], + [np.nan, 'b', 'd', 'r', 'b', 'c'], + ]) + elif request.param == 'pandas_categoricalonly_nonan': + return pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='category') + elif request.param == 'pandas_numericalonly_nonan': + return pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='float') + elif request.param == 'pandas_mixed_nonan': + frame = pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='category') + frame['B'] = pd.to_numeric(frame['B']) + return frame + elif request.param == 'pandas_categoricalonly_nan': + return pd.DataFrame([ + {'A': 1, 'B': 2, 'C': np.nan}, + {'A': 3, 'C': np.nan}, + ], dtype='category') + elif request.param == 'pandas_numericalonly_nan': + return pd.DataFrame([ + {'A': 1, 'B': 2, 'C': np.nan}, + {'A': 3, 'C': np.nan}, + ], dtype='float') + elif request.param == 'pandas_mixed_nan': + frame = pd.DataFrame([ + {'A': 1, 'B': 2, 'C': 8}, + {'A': 3, 
+            {'A': 3, 'B': 4},
+        ], dtype='category')
+        frame['B'] = pd.to_numeric(frame['B'])
+        return frame
+    elif request.param == 'pandas_string_nonan':
+        return pd.DataFrame([
+            {'A': 1, 'B': 2},
+            {'A': 3, 'B': 4},
+        ], dtype='string')
+    elif request.param == 'list_categoricalonly_nonan':
+        return [
+            ['a', 'b', 'c', 'd'],
+            ['e', 'f', 'c', 'd'],
+        ]
+    elif request.param == 'list_numericalonly_nonan':
+        return [
+            [1, 2, 3, 4],
+            [5, 6, 7, 8]
+        ]
+    elif request.param == 'list_mixed_nonan':
+        return [
+            ['a', 2, 3, 4],
+            ['b', 6, 7, 8]
+        ]
+    elif request.param == 'list_categoricalonly_nan':
+        return [
+            ['a', 'b', 'c', np.nan],
+            ['e', 'f', 'c', 'd'],
+        ]
+    elif request.param == 'list_numericalonly_nan':
+        return [
+            [1, 2, 3, np.nan],
+            [5, 6, 7, 8]
+        ]
+    elif request.param == 'list_mixed_nan':
+        return [
+            ['a', np.nan, 3, 4],
+            ['b', 6, 7, 8]
+        ]
+    elif 'sparse' in request.param:
+        # We expect the names to be of the type sparse_csc_nonan
+        sparse_, type_, nan_ = request.param.split('_')
+        if 'nonan' in nan_:
+            data = np.ones(3)
+        else:
+            data = np.array([1, 2, np.nan])
+
+        # Then the type of sparse
+        row_ind = np.array([0, 1, 2])
+        col_ind = np.array([1, 2, 1])
+        if 'csc' in type_:
+            return sparse.csc_matrix((data, (row_ind, col_ind)))
+        elif 'csr' in type_:
+            return sparse.csr_matrix((data, (row_ind, col_ind)))
+        elif 'coo' in type_:
+            return sparse.coo_matrix((data, (row_ind, col_ind)))
+        elif 'bsr' in type_:
+            return sparse.bsr_matrix((data, (row_ind, col_ind)))
+        elif 'lil' in type_:
+            return sparse.lil_matrix((data))
+        elif 'dok' in type_:
+            return sparse.dok_matrix(np.vstack((data, data, data)))
+        elif 'dia' in type_:
+            return sparse.dia_matrix(np.vstack((data, data, data)))
+        else:
+            raise ValueError("Unsupported indirect fixture {}".format(request.param))
+    elif 'openml' in request.param:
+        _, openml_id = request.param.split('_')
+        X, y = fetch_openml(data_id=int(openml_id),
+                            return_X_y=True, as_frame=True)
+        return X
+    else:
+        raise ValueError("Unsupported indirect fixture {}".format(request.param))
diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py
index 3d352d765..2daa271b7 100644
--- a/test/test_data/test_feature_validator.py
+++ b/test/test_data/test_feature_validator.py
@@ -13,154 +13,6 @@
 import sklearn.model_selection
 
 from autoPyTorch.data.tabular_feature_validator import TabularFeatureValidator
-from autoPyTorch.data.utils import megabytes
-
-
-# Fixtures to be used in this class. By default all elements have 100 datapoints
-@pytest.fixture
-def input_data_featuretest(request):
-    if request.param == 'numpy_categoricalonly_nonan':
-        return np.random.randint(10, size=(100, 10))
-    elif request.param == 'numpy_numericalonly_nonan':
-        return np.random.uniform(10, size=(100, 10))
-    elif request.param == 'numpy_mixed_nonan':
-        return np.column_stack([
-            np.random.uniform(10, size=(100, 3)),
-            np.random.randint(10, size=(100, 3)),
-            np.random.uniform(10, size=(100, 3)),
-            np.random.randint(10, size=(100, 1)),
-        ])
-    elif request.param == 'numpy_string_nonan':
-        return np.array([
-            ['a', 'b', 'c', 'a', 'b', 'c'],
-            ['a', 'b', 'd', 'r', 'b', 'c'],
-        ])
-    elif request.param == 'numpy_categoricalonly_nan':
-        array = np.random.randint(10, size=(100, 10)).astype('float')
-        array[50, 0:5] = np.nan
-        return array
-    elif request.param == 'numpy_numericalonly_nan':
-        array = np.full(fill_value=10.0, shape=(100, 10), dtype=np.float64)
-        array[50, 0:5] = np.nan
-        # Somehow array is changed to dtype object after np.nan
-        return array.astype('float')
-    elif request.param == 'numpy_mixed_nan':
-        array = np.column_stack([
-            np.random.uniform(10, size=(100, 3)),
-            np.random.randint(10, size=(100, 3)),
-            np.random.uniform(10, size=(100, 3)),
-            np.random.randint(10, size=(100, 1)),
-        ])
-        array[50, 0:5] = np.nan
-        return array
-    elif request.param == 'numpy_string_nan':
-        return np.array([
-            ['a', 'b', 'c', 'a', 'b', 'c'],
-            [np.nan, 'b', 'd', 'r', 'b', 'c'],
-        ])
-    elif request.param == 'pandas_categoricalonly_nonan':
-        return pd.DataFrame([
-            {'A': 1, 'B': 2},
-            {'A': 3, 'B': 4},
-        ], dtype='category')
-    elif request.param == 'pandas_numericalonly_nonan':
-        return pd.DataFrame([
-            {'A': 1, 'B': 2},
-            {'A': 3, 'B': 4},
-        ], dtype='float')
-    elif request.param == 'pandas_mixed_nonan':
-        frame = pd.DataFrame([
-            {'A': 1, 'B': 2},
-            {'A': 3, 'B': 4},
-        ], dtype='category')
-        frame['B'] = pd.to_numeric(frame['B'])
-        return frame
-    elif request.param == 'pandas_categoricalonly_nan':
-        return pd.DataFrame([
-            {'A': 1, 'B': 2, 'C': np.nan},
-            {'A': 3, 'C': np.nan},
-        ], dtype='category')
-    elif request.param == 'pandas_numericalonly_nan':
-        return pd.DataFrame([
-            {'A': 1, 'B': 2, 'C': np.nan},
-            {'A': 3, 'C': np.nan},
-        ], dtype='float')
-    elif request.param == 'pandas_mixed_nan':
-        frame = pd.DataFrame([
-            {'A': 1, 'B': 2, 'C': 8},
-            {'A': 3, 'B': 4},
-        ], dtype='category')
-        frame['B'] = pd.to_numeric(frame['B'])
-        return frame
-    elif request.param == 'pandas_string_nonan':
-        return pd.DataFrame([
-            {'A': 1, 'B': 2},
-            {'A': 3, 'B': 4},
-        ], dtype='string')
-    elif request.param == 'list_categoricalonly_nonan':
-        return [
-            ['a', 'b', 'c', 'd'],
-            ['e', 'f', 'c', 'd'],
-        ]
-    elif request.param == 'list_numericalonly_nonan':
-        return [
-            [1, 2, 3, 4],
-            [5, 6, 7, 8]
-        ]
-    elif request.param == 'list_mixed_nonan':
-        return [
-            ['a', 2, 3, 4],
-            ['b', 6, 7, 8]
-        ]
-    elif request.param == 'list_categoricalonly_nan':
-        return [
-            ['a', 'b', 'c', np.nan],
-            ['e', 'f', 'c', 'd'],
-        ]
-    elif request.param == 'list_numericalonly_nan':
-        return [
-            [1, 2, 3, np.nan],
-            [5, 6, 7, 8]
-        ]
-    elif request.param == 'list_mixed_nan':
-        return [
-            ['a', np.nan, 3, 4],
-            ['b', 6, 7, 8]
-        ]
-    elif 'sparse' in request.param:
-        # We expect the names to be of the type sparse_csc_nonan
-        sparse_, type_, nan_ = request.param.split('_')
-        if 'nonan' in nan_:
-            data = np.ones(3)
-        else:
-            data = np.array([1, 2, np.nan])
-
-        # Then the type of sparse
-        row_ind = np.array([0, 1, 2])
-        col_ind = np.array([1, 2, 1])
-        if 'csc' in type_:
-            return sparse.csc_matrix((data, (row_ind, col_ind)))
-        elif 'csr' in type_:
-            return sparse.csr_matrix((data, (row_ind, col_ind)))
-        elif 'coo' in type_:
-            return sparse.coo_matrix((data, (row_ind, col_ind)))
-        elif 'bsr' in type_:
-            return sparse.bsr_matrix((data, (row_ind, col_ind)))
-        elif 'lil' in type_:
-            return sparse.lil_matrix((data))
-        elif 'dok' in type_:
-            return sparse.dok_matrix(np.vstack((data, data, data)))
-        elif 'dia' in type_:
-            return sparse.dia_matrix(np.vstack((data, data, data)))
-        else:
-            ValueError("Unsupported indirect fixture {}".format(request.param))
-    elif 'openml' in request.param:
-        _, openml_id = request.param.split('_')
-        X, y = sklearn.datasets.fetch_openml(data_id=int(openml_id),
-                                             return_X_y=True, as_frame=True)
-        return X
-    else:
-        ValueError("Unsupported indirect fixture {}".format(request.param))
 
 
 # Actual checks for the features
@@ -509,10 +361,6 @@ def test_featurevalidator_new_data_after_fit(openml_id,
     transformed_X = validator.transform(X_test)
 
     # Basic Checking
-    if sparse.issparse(input_data_featuretest):
-        assert sparse.issparse(transformed_X)
-    else:
-        assert isinstance(transformed_X, np.ndarray)
     assert np.shape(X_test) == np.shape(transformed_X)
 
     # And then check proper error messages
@@ -558,47 +406,3 @@ def test_comparator():
         key=functools.cmp_to_key(validator._comparator)
     )
     assert ans == feat_type
-
-
-# Actual checks for the features
-@pytest.mark.parametrize(
-    'input_data_featuretest',
-    (
-        'numpy_numericalonly_nonan',
-        'numpy_numericalonly_nan',
-        'numpy_mixed_nan',
-        'pandas_numericalonly_nan',
-        'sparse_bsr_nonan',
-        'sparse_bsr_nan',
-        'sparse_coo_nonan',
-        'sparse_coo_nan',
-        'sparse_csc_nonan',
-        'sparse_csc_nan',
-        'sparse_csr_nonan',
-        'sparse_csr_nan',
-        'sparse_dia_nonan',
-        'sparse_dia_nan',
-        'sparse_dok_nonan',
-        'sparse_dok_nan',
-        'openml_40981',  # Australian
-    ),
-    indirect=True
-)
-def test_featurevalidator_reduce_precision(input_data_featuretest):
-    X_train, X_test = sklearn.model_selection.train_test_split(
-        input_data_featuretest, test_size=0.1, random_state=1)
-    validator = TabularFeatureValidator(dataset_compression={'memory_allocation': 0, 'methods': ['precision']})
-    validator.fit(X_train=X_train)
-    transformed_X_train = validator.transform(X_train.copy())
-
-    assert validator._reduced_dtype is not None
-    assert megabytes(transformed_X_train) < megabytes(X_train)
-
-    transformed_X_test = validator.transform(X_test.copy())
-    assert megabytes(transformed_X_test) < megabytes(X_test)
-    if hasattr(transformed_X_train, 'iloc'):
-        assert all(transformed_X_train.dtypes == transformed_X_test.dtypes)
-        assert all(transformed_X_train.dtypes == validator._precision)
-    else:
-        assert transformed_X_train.dtype == transformed_X_test.dtype
-        assert transformed_X_test.dtype == validator._reduced_dtype
diff --git a/test/test_data/test_utils.py b/test/test_data/test_utils.py
index 505860a94..4269c4e5f 100644
--- a/test/test_data/test_utils.py
+++ b/test/test_data/test_utils.py
@@ -1,19 +1,34 @@
+import warnings
+from test.test_data.utils import convert, dtype, size
 from typing import Mapping
 
 import numpy as np
 
-from pandas.testing import assert_frame_equal
+import pandas as pd
 
 import pytest
 
+from scipy.sparse import csr_matrix
+
 from sklearn.datasets import fetch_openml
 
+from autoPyTorch.constants import (
+    BINARY,
+    CLASSIFICATION_TASKS,
+    CONTINUOUS,
+    CONTINUOUSMULTIOUTPUT,
+    MULTICLASS,
+    MULTICLASSMULTIOUTPUT,
+    TABULAR_CLASSIFICATION,
+    TABULAR_REGRESSION
+)
 from autoPyTorch.data.utils import (
     default_dataset_compression_arg,
     get_dataset_compression_mapping,
     megabytes,
     reduce_dataset_size_if_too_large,
     reduce_precision,
+    subsample,
     validate_dataset_compression_arg
 )
 from autoPyTorch.utils.common import subsampler
@@ -22,14 +37,90 @@
 
 @pytest.mark.parametrize('openmlid', [2, 40984])
 @pytest.mark.parametrize('as_frame', [True, False])
 def test_reduce_dataset_if_too_large(openmlid, as_frame, n_samples):
-    X, _ = fetch_openml(data_id=openmlid, return_X_y=True, as_frame=as_frame)
+    X, y = fetch_openml(data_id=openmlid, return_X_y=True, as_frame=as_frame)
     X = subsampler(data=X, x=range(n_samples))
+    y = subsampler(data=y, x=range(n_samples))
+
+    X_converted, y_converted = reduce_dataset_size_if_too_large(
+        X.copy(),
+        y=y.copy(),
+        is_classification=True,
+        random_state=1,
+        memory_allocation=0.001)
+
+    assert X_converted.shape[0] < X.shape[0]
+    assert y_converted.shape[0] < y.shape[0]
 
-    X_converted = reduce_dataset_size_if_too_large(X.copy(), memory_allocation=0)
-    np.allclose(X, X_converted) if not as_frame else assert_frame_equal(X, X_converted, check_dtype=False)
     assert megabytes(X_converted) < megabytes(X)
 
 
+@pytest.mark.parametrize("X", [np.asarray([[1, 1, 1]] * 30)])
+@pytest.mark.parametrize("x_type", [list, np.ndarray, csr_matrix, pd.DataFrame])
+@pytest.mark.parametrize(
+    "y, task, output",
+    [
+        (np.asarray([0] * 15 + [1] * 15), TABULAR_CLASSIFICATION, BINARY),
+        (np.asarray([0] * 10 + [1] * 10 + [2] * 10), TABULAR_CLASSIFICATION, MULTICLASS),
+        (np.asarray([[1, 0, 1]] * 30), TABULAR_CLASSIFICATION, MULTICLASSMULTIOUTPUT),
+        (np.asarray([1.0] * 30), TABULAR_REGRESSION, CONTINUOUS),
+        (np.asarray([[1.0, 1.0, 1.0]] * 30), TABULAR_REGRESSION, CONTINUOUSMULTIOUTPUT),
+    ],
+)
+@pytest.mark.parametrize("y_type", [list, np.ndarray, pd.DataFrame, pd.Series])
+@pytest.mark.parametrize("random_state", [0])
+@pytest.mark.parametrize("sample_size", [0.25, 0.5, 5, 10])
+def test_subsample_validity(X, x_type, y, y_type, random_state, sample_size, task, output):
+    """Asserts the validity of the function with all valid types
+
+    We want to make sure that `subsample` works correctly with all the types listed
+    as x_type and y_type.
+
+    We also want to make sure it works with all kinds of target types.
+
+    The output should maintain the types, and subsample the correct amount.
+    (test adapted from autosklearn)
+    """
+    assert len(X) == len(y)  # Make sure our test data is correct
+
+    if y_type == pd.Series and output in [
+        MULTICLASSMULTIOUTPUT,
+        CONTINUOUSMULTIOUTPUT,
+    ]:
+        # We can't have a pd.Series with multiple values as it's 1 dimensional
+        pytest.skip("Can't have pd.Series as y when task is n-dimensional")
+
+    # Convert our data to its given x_type or y_type
+    X = convert(X, x_type)
+    y = convert(y, y_type)
+
+    # Subsample the data, ignoring any warnings
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        X_sampled, y_sampled = subsample(
+            X,
+            y=y,
+            random_state=random_state,
+            sample_size=sample_size,
+            is_classification=task in CLASSIFICATION_TASKS,
+        )
+
+    # Check that the types of X remain the same after subsampling
+    if isinstance(X, pd.DataFrame):
+        # Dataframe can have multiple types, one per column
+        assert list(dtype(X_sampled)) == list(dtype(X))
+    else:
+        assert dtype(X_sampled) == dtype(X)
+
+    # Check that the types of y remain the same after subsampling
+    if isinstance(y, pd.DataFrame):
+        assert list(dtype(y_sampled)) == list(dtype(y))
+    else:
+        assert dtype(y_sampled) == dtype(y)
+
+    # check the right amount of samples were taken
+    if sample_size < 1:
+        assert size(X_sampled) == int(sample_size * size(X))
+    else:
+        assert size(X_sampled) == sample_size
+
+
 def test_validate_dataset_compression_arg():
 
     data_compression_args = validate_dataset_compression_arg({}, 10)
@@ -37,8 +128,8 @@ def test_validate_dataset_compression_arg():
     # to fill in case args is empty
     assert data_compression_args is not None
 
-    # assert memory allocation is an integer after validation
-    assert isinstance(data_compression_args['memory_allocation'], int)
+    # assert memory allocation is a float after validation
+    assert isinstance(data_compression_args['memory_allocation'], float)
 
     # check whether the function raises an error
     # in case an unknown key is in args
@@ -120,8 +211,8 @@ def test_unsupported_errors():
                   ['a', 'b', 'c', 'a', 'b', 'c'],
                   ['a', 'b', 'd', 'r', 'b', 'c']])
     with pytest.raises(ValueError, match=r'X.dtype = .*'):
-        reduce_dataset_size_if_too_large(X, 0)
+        reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0)
 
     X = [[1, 2], [2, 3]]
     with pytest.raises(ValueError, match=r'Unrecognised data type of X, expected data type to be in .*'):
-        reduce_dataset_size_if_too_large(X, 0)
+        reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0)
diff --git a/test/test_data/test_validation.py b/test/test_data/test_validation.py
index 482c99769..f7755e35e 100644
--- a/test/test_data/test_validation.py
+++ b/test/test_data/test_validation.py
@@ -10,6 +10,7 @@
 import sklearn.model_selection
 
 from autoPyTorch.data.tabular_validator import TabularInputValidator
+from autoPyTorch.data.utils import megabytes
 
 
 @pytest.mark.parametrize('openmlid', [2, 40975, 40984])
@@ -137,3 +138,50 @@ def test_validation_unsupported():
             X=np.array([[0, 1, 0], [0, 1, 1]]),
             y=np.array([0, 1]),
         )
+
+
+@pytest.mark.parametrize(
+    'input_data_featuretest',
+    (
+        'numpy_numericalonly_nonan',
+        'numpy_numericalonly_nan',
+        'numpy_mixed_nan',
+        'pandas_numericalonly_nan',
+        'sparse_bsr_nonan',
+        'sparse_bsr_nan',
+        'sparse_coo_nonan',
+        'sparse_coo_nan',
+        'sparse_csc_nonan',
+        'sparse_csc_nan',
+        'sparse_csr_nonan',
+        'sparse_csr_nan',
+        'sparse_dia_nonan',
+        'sparse_dia_nan',
+        'sparse_dok_nonan',
+        'sparse_dok_nan',
+        'openml_40981',  # Australian
+    ),
+    indirect=True
+)
+def test_featurevalidator_dataset_compression(input_data_featuretest):
+    n_samples = input_data_featuretest.shape[0]
+    input_data_targets = np.random.random_sample((n_samples))
+    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
+        input_data_featuretest, input_data_targets, test_size=0.1, random_state=1)
+    validator = TabularInputValidator(
+        dataset_compression={'memory_allocation': 0.8 * megabytes(X_train), 'methods': ['precision', 'subsample']}
+    )
+    validator.fit(X_train=X_train, y_train=y_train)
+    transformed_X_train, _ = validator.transform(X_train.copy(), y_train.copy())
+
+    assert validator._reduced_dtype is not None
+    assert megabytes(transformed_X_train) < megabytes(X_train)
+
+    transformed_X_test, _ = validator.transform(X_test.copy(), y_test.copy())
+    assert megabytes(transformed_X_test) < megabytes(X_test)
+    if hasattr(transformed_X_train, 'iloc'):
+        assert all(transformed_X_train.dtypes == transformed_X_test.dtypes)
+        assert all(transformed_X_train.dtypes == validator._reduced_dtype)
+    else:
+        assert transformed_X_train.dtype == transformed_X_test.dtype
+        assert transformed_X_test.dtype == validator._reduced_dtype
diff --git a/test/test_data/utils.py b/test/test_data/utils.py
new file mode 100644
index 000000000..f1fff440a
--- /dev/null
+++ b/test/test_data/utils.py
@@ -0,0 +1,34 @@
+from typing import List
+
+import numpy as np
+
+import pandas as pd
+
+from scipy.sparse import spmatrix
+
+
+def convert(arr, objtype):
+    if objtype == np.ndarray:
+        return arr
+    elif objtype == list:
+        return arr.tolist()
+    else:
+        return objtype(arr)
+
+
+# Function to get the type of an obj
+def dtype(obj):
+    if isinstance(obj, List):
+        return type(obj[0][0]) if isinstance(obj[0], List) else type(obj[0])
+    elif isinstance(obj, pd.DataFrame):
+        return obj.dtypes
+    else:
+        return obj.dtype
+
+
+# Function to get the size of an object
+def size(obj):
+    if isinstance(obj, spmatrix):  # spmatrix doesn't support __len__
+        return obj.shape[0] if obj.shape[0] > 1 else obj.shape[1]
+    else:
+        return len(obj)
diff --git a/test/test_pipeline/components/preprocessing/test_feature_preprocessor.py b/test/test_pipeline/components/preprocessing/test_feature_preprocessor.py
index dd0a08d26..c4c03641c 100644
--- a/test/test_pipeline/components/preprocessing/test_feature_preprocessor.py
+++ b/test/test_pipeline/components/preprocessing/test_feature_preprocessor.py
@@ -113,7 +113,7 @@ def test_pipeline_fit_include(self, fit_dictionary_tabular, preprocessor):
             pipeline.fit(fit_dictionary_tabular)
         except Exception as e:
             if (
-                ("must be non-negative" or "contains negative values") in e.args[0]
+                ("must be non-negative" in e.args[0] or "contains negative values" in e.args[0])
                 and not fit_dictionary_tabular['dataset_properties']['issigned']
             ):
                 pytest.skip("Failure because scaler made data nonnegative.")
diff --git a/test/test_pipeline/components/training/test_training.py b/test/test_pipeline/components/training/test_training.py
index 6b277d36d..6deda30ad 100644
--- a/test/test_pipeline/components/training/test_training.py
+++ b/test/test_pipeline/components/training/test_training.py
@@ -423,7 +423,7 @@ def test_get_set_config_space(self):
 
 
 def test_early_stopping():
-    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
+    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary', 'output_shape': 0}
     trainer_choice = TrainerChoice(dataset_properties=dataset_properties)
 
     def dummy_performance(*args, **kwargs):
diff --git a/test/test_pipeline/test_tabular_regression.py b/test/test_pipeline/test_tabular_regression.py
index 75dc8a415..c6c475b91 100644
--- a/test/test_pipeline/test_tabular_regression.py
+++ b/test/test_pipeline/test_tabular_regression.py
@@ -1,6 +1,7 @@
 import os
 import re
 import unittest
+import unittest.mock
 
 from ConfigSpace.hyperparameters import (
     CategoricalHyperparameter,