2121 DEFAULT_RESAMPLING_PARAMETERS ,
2222 HoldOutFunc ,
2323 HoldOutFuncs ,
24- HoldoutValTypes
24+ HoldoutValTypes ,
25+ get_no_resampling_validators ,
26+ NoResamplingStrategyTypes ,
27+ NO_RESAMPLING_FN
2528)
2629from autoPyTorch .utils .common import FitRequirement
2730
@@ -78,7 +81,9 @@ def __init__(
7881 dataset_name : Optional [str ] = None ,
7982 val_tensors : Optional [BaseDatasetInputType ] = None ,
8083 test_tensors : Optional [BaseDatasetInputType ] = None ,
81- resampling_strategy : Union [CrossValTypes , HoldoutValTypes ] = HoldoutValTypes .holdout_validation ,
84+ resampling_strategy : Union [CrossValTypes ,
85+ HoldoutValTypes ,
86+ NoResamplingStrategyTypes ] = HoldoutValTypes .holdout_validation ,
8287 resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
8388 shuffle : Optional [bool ] = True ,
8489 seed : Optional [int ] = 42 ,
@@ -95,7 +100,7 @@ def __init__(
95100 validation data
96101 test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
97102 test data
98- resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
103+ resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes ]),
99104 (default=HoldoutValTypes.holdout_validation):
100105 strategy to split the training data.
101106 resampling_strategy_args (Optional[Dict[str, Any]]): arguments
@@ -117,9 +122,16 @@ def __init__(
117122 if not hasattr (train_tensors [0 ], 'shape' ):
118123 type_check (train_tensors , val_tensors )
119124 self .train_tensors , self .val_tensors , self .test_tensors = train_tensors , val_tensors , test_tensors
120126 self .cross_validators : Dict [str , CrossValFunc ] = {}
121127 self .holdout_validators : Dict [str , HoldOutFunc ] = {}
128+ self .no_resampling_validators : Dict [str , NO_RESAMPLING_FN ] = {}
122129 self .random_state = np .random .RandomState (seed = seed )
123135 self .shuffle = shuffle
124136 self .resampling_strategy = resampling_strategy
125137 self .resampling_strategy_args = resampling_strategy_args
@@ -144,6 +156,8 @@ def __init__(
144156 # Make sure cross validation splits are created once
145157 self .cross_validators = CrossValFuncs .get_cross_validators (* CrossValTypes )
146158 self .holdout_validators = HoldOutFuncs .get_holdout_validators (* HoldoutValTypes )
159+ self .no_resampling_validators = get_no_resampling_validators (* NoResamplingStrategyTypes )
160+
147161 self .splits = self .get_splits_from_resampling_strategy ()
148162
149163 # We also need to be able to transform the data, be it for pre-processing
@@ -211,7 +225,7 @@ def __len__(self) -> int:
211225 def _get_indices (self ) -> np .ndarray :
212226 return self .random_state .permutation (len (self )) if self .shuffle else np .arange (len (self ))
213227
214- def get_splits_from_resampling_strategy (self ) -> List [Tuple [List [int ], List [int ]]]:
228+ def get_splits_from_resampling_strategy (self ) -> List [Tuple [List [int ], Optional [ List [int ] ]]]:
215229 """
216230 Creates a set of splits based on a resampling strategy provided
217231
@@ -242,6 +256,8 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
242256 num_splits = cast (int , num_splits ),
243257 )
244258 )
259+ elif isinstance (self .resampling_strategy , NoResamplingStrategyTypes ):
260+ splits .append ((self .no_resampling_validators [self .resampling_strategy .name ](self ._get_indices ()), None ))
245261 else :
246262 raise ValueError (f"Unsupported resampling strategy={ self .resampling_strategy } " )
247263 return splits
@@ -313,7 +329,7 @@ def create_holdout_val_split(
313329 self .random_state , val_share , self ._get_indices (), ** kwargs )
314330 return train , val
315331
316- def get_dataset_for_training (self , split_id : int ) -> Tuple [ Dataset , Dataset ] :
332+ def get_dataset_for_training (self , split_id : int , train : bool ) -> Dataset :
317333 """
318334 The above split methods employ the Subset to internally subsample the whole dataset.
319335
@@ -327,8 +343,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
327343 Dataset: the reduced dataset to be used for testing
328344 """
329345 # Subset creates a dataset. Splits is a (train_indices, test_indices) tuple
330- return (TransformSubset (self , self .splits [split_id ][0 ], train = True ),
331- TransformSubset (self , self .splits [split_id ][1 ], train = False ))
346+ return TransformSubset (self , self .splits [split_id ][0 ], train = train )
332347
333348 def replace_data (self , X_train : BaseDatasetInputType ,
334349 X_test : Optional [BaseDatasetInputType ]) -> 'BaseDataset' :
0 commit comments