Skip to content

Commit c45087b

Browse files
[FIX] fix seed in splits
1 parent bbe067d commit c45087b

File tree

2 files changed

+39
-23
lines changed

2 files changed

+39
-23
lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(
118118
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
119119
self.cross_validators: Dict[str, CrossValFunc] = {}
120120
self.holdout_validators: Dict[str, HoldOutFunc] = {}
121-
self.rng = np.random.RandomState(seed=seed)
121+
self.random_state = np.random.RandomState(seed=seed)
122122
self.shuffle = shuffle
123123
self.resampling_strategy = resampling_strategy
124124
self.resampling_strategy_args = resampling_strategy_args
@@ -205,7 +205,7 @@ def __len__(self) -> int:
205205
return self.train_tensors[0].shape[0]
206206

207207
def _get_indices(self) -> np.ndarray:
208-
return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))
208+
return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
209209

210210
def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
211211
"""
@@ -271,7 +271,7 @@ def create_cross_val_splits(
271271
# we need additional information about the data for stratification
272272
kwargs["stratify"] = self.train_tensors[-1]
273273
splits = self.cross_validators[cross_val_type.name](
274-
num_splits, self._get_indices(), **kwargs)
274+
self.random_state, num_splits, self._get_indices(), **kwargs)
275275
return splits
276276

277277
def create_holdout_val_split(
@@ -305,7 +305,8 @@ def create_holdout_val_split(
305305
if holdout_val_type.is_stratified():
306306
# we need additional information about the data for stratification
307307
kwargs["stratify"] = self.train_tensors[-1]
308-
train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
308+
train, val = self.holdout_validators[holdout_val_type.name](
309+
self.random_state, val_share, self._get_indices(), **kwargs)
309310
return train, val
310311

311312
def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@
1818
# Use callback protocol as workaround, since callable with function fields count 'self' as argument
1919
class CrossValFunc(Protocol):
2020
def __call__(self,
21+
random_state: np.random.RandomState,
2122
num_splits: int,
2223
indices: np.ndarray,
2324
stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
2425
...
2526

2627

2728
class HoldOutFunc(Protocol):
28-
def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
29+
def __call__(self, random_state: np.random.RandomState, val_share: float,
30+
indices: np.ndarray, stratify: Optional[Any]
2931
) -> Tuple[np.ndarray, np.ndarray]:
3032
...
3133

@@ -85,35 +87,42 @@ def is_stratified(self) -> bool:
8587
'val_share': 0.33,
8688
},
8789
CrossValTypes.k_fold_cross_validation: {
88-
'num_splits': 3,
90+
'num_splits': 5,
8991
},
9092
CrossValTypes.stratified_k_fold_cross_validation: {
91-
'num_splits': 3,
93+
'num_splits': 5,
9294
},
9395
CrossValTypes.shuffle_split_cross_validation: {
94-
'num_splits': 3,
96+
'num_splits': 5,
9597
},
9698
CrossValTypes.time_series_cross_validation: {
97-
'num_splits': 3,
99+
'num_splits': 5,
98100
},
99101
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
100102

101103

102104
class HoldOutFuncs():
103105
@staticmethod
104-
def holdout_validation(val_share: float,
106+
def holdout_validation(random_state: np.random.RandomState,
107+
val_share: float,
105108
indices: np.ndarray,
106109
**kwargs: Any
107110
) -> Tuple[np.ndarray, np.ndarray]:
108-
train, val = train_test_split(indices, test_size=val_share, shuffle=False)
111+
shuffle = kwargs.get('shuffle', True)
112+
train, val = train_test_split(indices, test_size=val_share,
113+
shuffle=shuffle,
114+
random_state=random_state if shuffle else None,
115+
)
109116
return train, val
110117

111118
@staticmethod
112-
def stratified_holdout_validation(val_share: float,
119+
def stratified_holdout_validation(random_state: np.random.RandomState,
120+
val_share: float,
113121
indices: np.ndarray,
114122
**kwargs: Any
115123
) -> Tuple[np.ndarray, np.ndarray]:
116-
train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"])
124+
train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"],
125+
random_state=random_state)
117126
return train, val
118127

119128
@classmethod
@@ -128,34 +137,38 @@ def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str
128137

129138
class CrossValFuncs():
130139
@staticmethod
131-
def shuffle_split_cross_validation(num_splits: int,
140+
def shuffle_split_cross_validation(random_state: np.random.RandomState,
141+
num_splits: int,
132142
indices: np.ndarray,
133143
**kwargs: Any
134144
) -> List[Tuple[np.ndarray, np.ndarray]]:
135-
cv = ShuffleSplit(n_splits=num_splits)
145+
cv = ShuffleSplit(n_splits=num_splits, random_state=random_state)
136146
splits = list(cv.split(indices))
137147
return splits
138148

139149
@staticmethod
140-
def stratified_shuffle_split_cross_validation(num_splits: int,
150+
def stratified_shuffle_split_cross_validation(random_state: np.random.RandomState,
151+
num_splits: int,
141152
indices: np.ndarray,
142153
**kwargs: Any
143154
) -> List[Tuple[np.ndarray, np.ndarray]]:
144-
cv = StratifiedShuffleSplit(n_splits=num_splits)
155+
cv = StratifiedShuffleSplit(n_splits=num_splits, random_state=random_state)
145156
splits = list(cv.split(indices, kwargs["stratify"]))
146157
return splits
147158

148159
@staticmethod
149-
def stratified_k_fold_cross_validation(num_splits: int,
160+
def stratified_k_fold_cross_validation(random_state: np.random.RandomState,
161+
num_splits: int,
150162
indices: np.ndarray,
151163
**kwargs: Any
152164
) -> List[Tuple[np.ndarray, np.ndarray]]:
153-
cv = StratifiedKFold(n_splits=num_splits)
165+
cv = StratifiedKFold(n_splits=num_splits, random_state=random_state)
154166
splits = list(cv.split(indices, kwargs["stratify"]))
155167
return splits
156168

157169
@staticmethod
158-
def k_fold_cross_validation(num_splits: int,
170+
def k_fold_cross_validation(random_state: np.random.RandomState,
171+
num_splits: int,
159172
indices: np.ndarray,
160173
**kwargs: Any
161174
) -> List[Tuple[np.ndarray, np.ndarray]]:
@@ -169,12 +182,14 @@ def k_fold_cross_validation(num_splits: int,
169182
Returns:
170183
splits (List[Tuple[List, List]]): list of tuples of training and validation indices
171184
"""
172-
cv = KFold(n_splits=num_splits)
185+
shuffle = kwargs.get('shuffle', True)
186+
cv = KFold(n_splits=num_splits, random_state=random_state if shuffle else None, shuffle=shuffle)
173187
splits = list(cv.split(indices))
174188
return splits
175189

176190
@staticmethod
177-
def time_series_cross_validation(num_splits: int,
191+
def time_series_cross_validation(random_state: np.random.RandomState,
192+
num_splits: int,
178193
indices: np.ndarray,
179194
**kwargs: Any
180195
) -> List[Tuple[np.ndarray, np.ndarray]]:
@@ -196,7 +211,7 @@ def time_series_cross_validation(num_splits: int,
196211
([0, 1, 2], [3])]
197212
198213
"""
199-
cv = TimeSeriesSplit(n_splits=num_splits)
214+
cv = TimeSeriesSplit(n_splits=num_splits, random_state=random_state)
200215
splits = list(cv.split(indices))
201216
return splits
202217

0 commit comments

Comments
 (0)