@@ -1,3 +1,5 @@
+import os
+import uuid
 from abc import ABCMeta
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
 
@@ -13,18 +15,17 @@
 
 from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
 from autoPyTorch.datasets.resampling_strategy import (
-    CROSS_VAL_FN,
+    CrossValFunc,
+    CrossValFuncs,
     CrossValTypes,
     DEFAULT_RESAMPLING_PARAMETERS,
-    HOLDOUT_FN,
-    HoldoutValTypes,
-    get_cross_validators,
-    get_holdout_validators,
-    is_stratified,
+    HoldOutFunc,
+    HoldOutFuncs,
+    HoldoutValTypes
 )
-from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
+from autoPyTorch.utils.common import FitRequirement
 
-BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
+BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
 
 
 def check_valid_data(data: Any) -> None:
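Note for reviewers: the new `CrossValFunc`/`HoldOutFunc` aliases and the `CrossValFuncs`/`HoldOutFuncs` containers live in `resampling_strategy.py` and are not shown in this diff. A minimal sketch of the shape the call sites below imply; the `k_fold_cross_validation` body here is a stand-in, not the library's implementation:

```python
from typing import Any, Callable, Dict, List, Tuple

import numpy as np

# Callable aliases replacing the old CROSS_VAL_FN / HOLDOUT_FN names.
CrossValFunc = Callable[..., List[Tuple[np.ndarray, np.ndarray]]]
HoldOutFunc = Callable[..., Tuple[np.ndarray, np.ndarray]]


class CrossValFuncs:
    @staticmethod
    def k_fold_cross_validation(num_splits: int, indices: np.ndarray,
                                **kwargs: Any) -> List[Tuple[np.ndarray, np.ndarray]]:
        # Stand-in splitter: rotate one chunk out as validation per fold.
        folds = np.array_split(indices, num_splits)
        return [(np.concatenate(folds[:i] + folds[i + 1:]), folds[i])
                for i in range(num_splits)]

    @classmethod
    def get_cross_validators(cls, *cross_val_types: Any) -> Dict[str, CrossValFunc]:
        # Registry keyed by enum-member name, matching the
        # self.cross_validators[cross_val_type.name] lookups below.
        return {t.name: getattr(cls, t.name) for t in cross_val_types}
```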
@@ -33,7 +34,8 @@ def check_valid_data(data: Any) -> None:
             'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')
 
 
-def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
+def type_check(train_tensors: BaseDatasetInputType,
+               val_tensors: Optional[BaseDatasetInputType] = None) -> None:
     """To avoid unexpected behavior, we use loops over indices."""
     for i in range(len(train_tensors)):
         check_valid_data(train_tensors[i])
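A quick illustration of what the (unchanged) validation accepts, assuming `type_check` and `check_valid_data` are imported from this module; the corresponding loop over `val_tensors` continues below the hunk:

```python
import numpy as np

X, y = np.zeros((5, 3)), np.zeros(5)
type_check((X, y))  # passes: ndarrays provide __getitem__ and __len__

try:
    # a generator has neither attribute, so check_valid_data raises
    type_check(((row for row in X), y))
except ValueError as err:
    print(err)
```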
@@ -49,8 +51,8 @@ class TransformSubset(Subset):
     we require different transformation for each data point.
     This class helps to take the subset of the dataset
     with either training or validation transformation.
-
-    We achieve so by adding a train flag to the pytorch subset
+    The TransformSubset allows to add train flags
+    while indexing the main dataset towards this goal.
 
     Attributes:
         dataset (BaseDataset/Dataset): Dataset to sample the subset
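For context, the mechanism the reworded docstring describes is unchanged in this hunk; roughly:

```python
from typing import Sequence

from torch.utils.data import Dataset, Subset


class TransformSubset(Subset):
    def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None:
        super().__init__(dataset, indices)
        self.train = train

    def __getitem__(self, idx: int):
        # Forward the train flag so the parent dataset can choose between
        # its train and validation transform pipelines for this sample.
        return self.dataset.__getitem__(self.indices[idx], self.train)
```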
@@ -71,10 +73,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
 class BaseDataset(Dataset, metaclass=ABCMeta):
     def __init__(
         self,
-        train_tensors: BaseDatasetType,
+        train_tensors: BaseDatasetInputType,
         dataset_name: Optional[str] = None,
-        val_tensors: Optional[BaseDatasetType] = None,
-        test_tensors: Optional[BaseDatasetType] = None,
+        val_tensors: Optional[BaseDatasetInputType] = None,
+        test_tensors: Optional[BaseDatasetInputType] = None,
         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
@@ -106,14 +108,16 @@ def __init__(
             val_transforms (Optional[torchvision.transforms.Compose]):
                 Additional Transforms to be applied to the validation/test data
         """
-        self.dataset_name = dataset_name if dataset_name is not None \
-            else hash_array_or_matrix(train_tensors[0])
+        self.dataset_name = dataset_name
+
+        if self.dataset_name is None:
+            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
 
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
-        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
+        self.cross_validators: Dict[str, CrossValFunc] = {}
+        self.holdout_validators: Dict[str, HoldOutFunc] = {}
         self.rng = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
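The default name switches from a content hash to a time/process-based identifier, presumably so that runs over identical data no longer collide on the same name. What `uuid.uuid1(clock_seq=os.getpid())` yields (values vary by host, time, and PID):

```python
import os
import uuid

# uuid1 combines the host ID and a timestamp; passing the PID as clock_seq
# keeps names distinct across worker processes started at the same instant.
name = str(uuid.uuid1(clock_seq=os.getpid()))
print(name)  # e.g. '5f2a7d3e-6b1c-11eb-8001-0242ac130002'
```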
@@ -134,8 +138,8 @@ def __init__(
         self.is_small_preprocess = True
 
         # Make sure cross validation splits are created once
-        self.cross_validators = get_cross_validators(*CrossValTypes)
-        self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
+        self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
+        self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
         self.splits = self.get_splits_from_resampling_strategy()
 
         # We also need to be able to transform the data, be it for pre-processing
@@ -263,7 +267,7 @@ def create_cross_val_splits(
         if not isinstance(cross_val_type, CrossValTypes):
             raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
         kwargs = {}
-        if is_stratified(cross_val_type):
+        if cross_val_type.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
         splits = self.cross_validators[cross_val_type.name](
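`is_stratified` moves from a free function in `resampling_strategy.py` onto the enum itself, so the check travels with the type. A minimal sketch of the pattern; member names and values here are illustrative:

```python
from enum import IntEnum


class CrossValTypes(IntEnum):
    k_fold_cross_validation = 1
    stratified_k_fold_cross_validation = 2

    def is_stratified(self) -> bool:
        # Stratified strategies are the ones that need the targets
        # forwarded as the `stratify` kwarg, as done above.
        return self.name.startswith('stratified')


assert CrossValTypes.stratified_k_fold_cross_validation.is_stratified()
assert not CrossValTypes.k_fold_cross_validation.is_stratified()
```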
@@ -298,7 +302,7 @@ def create_holdout_val_split(
         if not isinstance(holdout_val_type, HoldoutValTypes):
             raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
         kwargs = {}
-        if is_stratified(holdout_val_type):
+        if holdout_val_type.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
         train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
@@ -321,7 +325,8 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         return (TransformSubset(self, self.splits[split_id][0], train=True),
                 TransformSubset(self, self.splits[split_id][1], train=False))
 
-    def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
+    def replace_data(self, X_train: BaseDatasetInputType,
+                     X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
         """
         To speed up the training of small dataset, early pre-processing of the data
         can be made on the fly by the pipeline.