@@ -82,13 +82,7 @@ def __init__(
8282 dataset_name : Optional [str ] = None ,
8383 val_tensors : Optional [BaseDatasetInputType ] = None ,
8484 test_tensors : Optional [BaseDatasetInputType ] = None ,
85- << << << < HEAD
8685 resampling_strategy : ResamplingStrategies = HoldoutValTypes .holdout_validation ,
87- == == == =
88- resampling_strategy : Union [CrossValTypes ,
89- HoldoutValTypes ,
90- NoResamplingStrategyTypes ] = HoldoutValTypes .holdout_validation ,
91- >> >> >> > Create fit evaluator , no resampling strategy and fix bug for test statistics
9286 resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
9387 shuffle : Optional [bool ] = True ,
9488 seed : Optional [int ] = 42 ,
@@ -105,12 +99,7 @@ def __init__(
10599 validation data
106100 test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
107101 test data
108- <<<<<<< HEAD
109102 resampling_strategy (RESAMPLING_STRATEGIES: default=HoldoutValTypes.holdout_validation):
110- =======
111- resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]),
112- (default=HoldoutValTypes.holdout_validation):
113- >>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
114103 strategy to split the training data.
115104 resampling_strategy_args (Optional[Dict[str, Any]]): arguments
116105 required for the chosen resampling strategy. If None, uses
@@ -132,17 +121,11 @@ def __init__(
132121 if not hasattr (train_tensors [0 ], 'shape' ):
133122 type_check (train_tensors , val_tensors )
134123 self .train_tensors , self .val_tensors , self .test_tensors = train_tensors , val_tensors , test_tensors
135- << << << < HEAD
136124 self .cross_validators : Dict [str , CrossValFunc ] = {}
137125 self .holdout_validators : Dict [str , HoldOutFunc ] = {}
138126 self .no_resampling_validators : Dict [str , NoResamplingFunc ] = {}
139127 self .random_state = np .random .RandomState (seed = seed )
140- == == == =
141- self .cross_validators : Dict [str , CROSS_VAL_FN ] = {}
142- self .holdout_validators : Dict [str , HOLDOUT_FN ] = {}
143- self .no_resampling_validators : Dict [str , NO_RESAMPLING_FN ] = {}
144- self .rng = np .random .RandomState (seed = seed )
145- >> >> >> > Fix mypy and flake
128+ self .no_resampling_validators : Dict [str , NoResamplingFunc ] = {}
146129 self .shuffle = shuffle
147130 self .resampling_strategy = resampling_strategy
148131 self .resampling_strategy_args = resampling_strategy_args
@@ -167,11 +150,8 @@ def __init__(
167150 # Make sure cross validation splits are created once
168151 self .cross_validators = CrossValFuncs .get_cross_validators (* CrossValTypes )
169152 self .holdout_validators = HoldOutFuncs .get_holdout_validators (* HoldoutValTypes )
170- < << << << HEAD
153+
171154 self .no_resampling_validators = NoResamplingFuncs .get_no_resampling_validators (* NoResamplingStrategyTypes )
172- == == == =
173- self .no_resampling_validators = get_no_resampling_validators (* NoResamplingStrategyTypes )
174- >> >> >> > Create fit evaluator , no resampling strategy and fix bug for test statistics
175155
176156 self .splits = self .get_splits_from_resampling_strategy ()
177157
@@ -272,12 +252,8 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], Optional[
272252 )
273253 )
274254 elif isinstance (self .resampling_strategy , NoResamplingStrategyTypes ):
275- << << << < HEAD
276255 splits .append ((self .no_resampling_validators [self .resampling_strategy .name ](self .random_state ,
277256 self ._get_indices ()), None ))
278- == == == =
279- splits .append ((self .no_resampling_validators [self .resampling_strategy .name ](self ._get_indices ()), None ))
280- >> > >> > > Create fit evaluator , no resampling strategy and fix bug for test statistics
281257 else :
282258 raise ValueError (f"Unsupported resampling strategy={ self .resampling_strategy } " )
283259 return splits
@@ -349,11 +325,7 @@ def create_holdout_val_split(
349325 self .random_state , val_share , self ._get_indices (), ** kwargs )
350326 return train , val
351327
352- << << < << HEAD
353328 def get_dataset (self , split_id : int , train : bool ) -> Dataset :
354- == == == =
355- def get_dataset_for_training (self , split_id : int , train : bool ) - > Dataset :
356- >> >> >> > Create fit evaluator , no resampling strategy and fix bug for test statistics
357329 """
358330 The above split methods employ the Subset to internally subsample the whole dataset.
359331
@@ -368,7 +340,6 @@ def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
368340 Dataset: the reduced dataset to be used for testing
369341 """
370342 # Subset creates a dataset. Splits is a (train_indices, test_indices) tuple
371- << << < << HEAD
372343 if split_id >= len (self .splits ): # old version: split_id > len(self.splits)
373344 raise IndexError (f"self.splits index out of range, got split_id={ split_id } "
374345 f" (>= num_splits={ len (self .splits )} )" )
@@ -377,9 +348,6 @@ def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
377348 raise ValueError ("Specified fold (or subset) does not exist" )
378349
379350 return TransformSubset (self , indices , train = train )
380- == == == =
381- return TransformSubset (self , self .splits [split_id ][0 ], train = train )
382- >> >> > >> Create fit evaluator , no resampling strategy and fix bug for test statistics
383351
384352 def replace_data (self , X_train : BaseDatasetInputType ,
385353 X_test : Optional [BaseDatasetInputType ]) -> 'BaseDataset' :
0 commit comments