8
8
from autoPyTorch .datasets .resampling_strategy import (
9
9
CrossValTypes ,
10
10
HoldoutValTypes ,
11
- get_cross_validators ,
12
- get_holdout_validators
11
+ CrossValFuncs ,
12
+ HoldOutFuncs
13
13
)
14
14
15
15
TIME_SERIES_FORECASTING_INPUT = Tuple [np .ndarray , np .ndarray ] # currently only numpy arrays are supported
@@ -60,8 +60,8 @@ def __init__(self,
60
60
train_transforms = train_transforms ,
61
61
val_transforms = val_transforms ,
62
62
)
63
- self .cross_validators = get_cross_validators (CrossValTypes .time_series_cross_validation )
64
- self .holdout_validators = get_holdout_validators (HoldoutValTypes .holdout_validation )
63
+ self .cross_validators = CrossValFuncs . get_cross_validators (CrossValTypes .time_series_cross_validation )
64
+ self .holdout_validators = HoldOutFuncs . get_holdout_validators (HoldoutValTypes .holdout_validation )
65
65
66
66
67
67
def _check_time_series_forecasting_inputs (target_variables : Tuple [int ],
@@ -117,13 +117,13 @@ def __init__(self,
117
117
val = val ,
118
118
task_type = "time_series_classification" )
119
119
super ().__init__ (train_tensors = train , val_tensors = val , shuffle = True )
120
- self .cross_validators = get_cross_validators (
120
+ self .cross_validators = CrossValFuncs . get_cross_validators (
121
121
CrossValTypes .stratified_k_fold_cross_validation ,
122
122
CrossValTypes .k_fold_cross_validation ,
123
123
CrossValTypes .shuffle_split_cross_validation ,
124
124
CrossValTypes .stratified_shuffle_split_cross_validation
125
125
)
126
- self .holdout_validators = get_holdout_validators (
126
+ self .holdout_validators = HoldOutFuncs . get_holdout_validators (
127
127
HoldoutValTypes .holdout_validation ,
128
128
HoldoutValTypes .stratified_holdout_validation
129
129
)
@@ -135,11 +135,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
135
135
val = val ,
136
136
task_type = "time_series_regression" )
137
137
super ().__init__ (train_tensors = train , val_tensors = val , shuffle = True )
138
- self .cross_validators = get_cross_validators (
138
+ self .cross_validators = CrossValFuncs . get_cross_validators (
139
139
CrossValTypes .k_fold_cross_validation ,
140
140
CrossValTypes .shuffle_split_cross_validation
141
141
)
142
- self .holdout_validators = get_holdout_validators (
142
+ self .holdout_validators = HoldOutFuncs . get_holdout_validators (
143
143
HoldoutValTypes .holdout_validation
144
144
)
145
145
0 commit comments