1
1
# -*- encoding: utf-8 -*-
2
2
import logging
3
- from typing import Any , Mapping , Optional , Union
3
+ from typing import Optional , Tuple , Union
4
+
5
+ import numpy as np
6
+
7
+ from scipy .sparse import issparse
4
8
5
9
from autoPyTorch .data .base_validator import BaseInputValidator
6
- from autoPyTorch .data .tabular_feature_validator import TabularFeatureValidator
7
- from autoPyTorch .data .tabular_target_validator import TabularTargetValidator
10
+ from autoPyTorch .data .tabular_feature_validator import SupportedFeatTypes , TabularFeatureValidator
11
+ from autoPyTorch .data .tabular_target_validator import SupportedTargetTypes , TabularTargetValidator
12
+ from autoPyTorch .data .utils import (
13
+ DatasetCompressionInputType ,
14
+ DatasetCompressionSpec ,
15
+ DatasetDTypeContainerType ,
16
+ reduce_dataset_size_if_too_large
17
+ )
18
+ from autoPyTorch .utils .common import ispandas
8
19
from autoPyTorch .utils .logging_ import PicklableClientLogger , get_named_client_logger
9
20
10
21
@@ -27,16 +38,22 @@ class TabularInputValidator(BaseInputValidator):
27
38
target_validator (TargetValidator):
28
39
A TargetValidator instance used to validate and encode (in case of classification)
29
40
the target values
41
+ dataset_compression (Optional[DatasetCompressionSpec]):
42
+ specifications for dataset compression. For more info check
43
+ documentation for `BaseTask.get_dataset`.
30
44
"""
31
45
def __init__ (
32
46
self ,
33
47
is_classification : bool = False ,
34
48
logger_port : Optional [int ] = None ,
35
- dataset_compression : Optional [Mapping [str , Any ]] = None ,
36
- ) -> None :
49
+ dataset_compression : Optional [DatasetCompressionSpec ] = None ,
50
+ seed : int = 42 ,
51
+ ):
52
+ self .dataset_compression = dataset_compression
53
+ self ._reduced_dtype : Optional [DatasetDTypeContainerType ] = None
37
54
self .is_classification = is_classification
38
55
self .logger_port = logger_port
39
- self .dataset_compression = dataset_compression
56
+ self .seed = seed
40
57
if self .logger_port is not None :
41
58
self .logger : Union [logging .Logger , PicklableClientLogger ] = get_named_client_logger (
42
59
name = 'Validation' ,
@@ -46,10 +63,59 @@ def __init__(
46
63
self .logger = logging .getLogger ('Validation' )
47
64
48
65
self .feature_validator = TabularFeatureValidator (
49
- dataset_compression = self .dataset_compression ,
50
66
logger = self .logger )
51
67
self .target_validator = TabularTargetValidator (
52
68
is_classification = self .is_classification ,
53
69
logger = self .logger
54
70
)
55
71
self ._is_fitted = False
72
+
73
+ def _compress_dataset (
74
+ self ,
75
+ X : DatasetCompressionInputType ,
76
+ y : SupportedTargetTypes ,
77
+ ) -> DatasetCompressionInputType :
78
+ """
79
+ Compress the dataset. This function ensures that
80
+ the testing data is converted to the same dtype as
81
+ the training data.
82
+ See `autoPyTorch.data.utils.reduce_dataset_size_if_too_large`
83
+ for more information.
84
+
85
+ Args:
86
+ X (DatasetCompressionInputType):
87
+ features of dataset
88
+ y (SupportedTargetTypes):
89
+ targets of dataset
90
+ Returns:
91
+ DatasetCompressionInputType:
92
+ Compressed dataset.
93
+ """
94
+ is_dataframe = ispandas (X )
95
+ is_reducible_type = isinstance (X , np .ndarray ) or issparse (X ) or is_dataframe
96
+ if not is_reducible_type or self .dataset_compression is None :
97
+ return X , y
98
+ elif self ._reduced_dtype is not None :
99
+ X = X .astype (self ._reduced_dtype )
100
+ return X , y
101
+ else :
102
+ X , y = reduce_dataset_size_if_too_large (
103
+ X ,
104
+ y = y ,
105
+ is_classification = self .is_classification ,
106
+ random_state = self .seed ,
107
+ ** self .dataset_compression # type: ignore [arg-type]
108
+ )
109
+ self ._reduced_dtype = dict (X .dtypes ) if is_dataframe else X .dtype
110
+ return X , y
111
+
112
+ def transform (
113
+ self ,
114
+ X : SupportedFeatTypes ,
115
+ y : Optional [SupportedTargetTypes ] = None ,
116
+ ) -> Tuple [np .ndarray , Optional [np .ndarray ]]:
117
+
118
+ X , y = super ().transform (X , y )
119
+ X_reduced , y_reduced = self ._compress_dataset (X , y )
120
+
121
+ return X_reduced , y_reduced
0 commit comments