Skip to content

Commit 87534fd

Browse files
committed
Modified time_series_dataset.py to be compatible with resampling_strategy.py
1 parent c6d046b commit 87534fd

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

autoPyTorch/datasets/time_series_dataset.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from autoPyTorch.datasets.resampling_strategy import (
99
CrossValTypes,
1010
HoldoutValTypes,
11-
get_cross_validators,
12-
get_holdout_validators
11+
CrossValFuncs,
12+
HoldOutFuncs
1313
)
1414

1515
TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported
@@ -60,8 +60,8 @@ def __init__(self,
6060
train_transforms=train_transforms,
6161
val_transforms=val_transforms,
6262
)
63-
self.cross_validators = get_cross_validators(CrossValTypes.time_series_cross_validation)
64-
self.holdout_validators = get_holdout_validators(HoldoutValTypes.holdout_validation)
63+
self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation)
64+
self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation)
6565

6666

6767
def _check_time_series_forecasting_inputs(target_variables: Tuple[int],
@@ -117,13 +117,13 @@ def __init__(self,
117117
val=val,
118118
task_type="time_series_classification")
119119
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
120-
self.cross_validators = get_cross_validators(
120+
self.cross_validators = CrossValFuncs.get_cross_validators(
121121
CrossValTypes.stratified_k_fold_cross_validation,
122122
CrossValTypes.k_fold_cross_validation,
123123
CrossValTypes.shuffle_split_cross_validation,
124124
CrossValTypes.stratified_shuffle_split_cross_validation
125125
)
126-
self.holdout_validators = get_holdout_validators(
126+
self.holdout_validators = HoldOutFuncs.get_holdout_validators(
127127
HoldoutValTypes.holdout_validation,
128128
HoldoutValTypes.stratified_holdout_validation
129129
)
@@ -135,11 +135,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
135135
val=val,
136136
task_type="time_series_regression")
137137
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
138-
self.cross_validators = get_cross_validators(
138+
self.cross_validators = CrossValFuncs.get_cross_validators(
139139
CrossValTypes.k_fold_cross_validation,
140140
CrossValTypes.shuffle_split_cross_validation
141141
)
142-
self.holdout_validators = get_holdout_validators(
142+
self.holdout_validators = HoldOutFuncs.get_holdout_validators(
143143
HoldoutValTypes.holdout_validation
144144
)
145145

0 commit comments

Comments
 (0)