
Commit f2fb6d4

Create fit evaluator and no-resampling strategy; fix bug in test statistics
- Fix mypy and flake8 issues
- Fix the check for X_test when making the test data loader
- Fix a bug in the lookahead hyperparameters where "lookahead" was repeated in the hyperparameter name
- Make passing tests in the API easier
- Fix a bug in the trainer's weighted-loss code for regression
1 parent: 2b1b663

21 files changed: +773 / -91 lines

autoPyTorch/api/base_task.py

Lines changed: 5 additions & 5 deletions
@@ -40,7 +40,7 @@
)
from autoPyTorch.data.base_validator import BaseInputValidator
from autoPyTorch.datasets.base_dataset import BaseDataset, BaseDatasetPropertiesType
-from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
+from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
from autoPyTorch.evaluation.abstract_evaluator import fit_and_suppress_warnings
@@ -172,7 +172,9 @@ def __init__(
        include_components: Optional[Dict[str, Any]] = None,
        exclude_components: Optional[Dict[str, Any]] = None,
        backend: Optional[Backend] = None,
-        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        resampling_strategy: Union[CrossValTypes,
+                                   HoldoutValTypes,
+                                   NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
        task_type: Optional[str] = None
@@ -1390,9 +1392,6 @@ def fit_pipeline(
        disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
    ) -> Tuple[Optional[BasePipeline], RunInfo, RunValue, BaseDataset]:
        """
-        Fit a pipeline on the given task for the budget.
-        A pipeline configuration can be specified if None,
-        uses default
        Fit uses the estimator pipeline_config attribute, which the user
        can interact via the get_pipeline_config()/set_pipeline_config()
        methods.
@@ -1494,6 +1493,7 @@ def fit_pipeline(
            (BaseDataset):
                Dataset created from the given tensors
        """
+        self.dataset_name = dataset.dataset_name

        if dataset is None:
            if (

autoPyTorch/api/tabular_classification.py

Lines changed: 6 additions & 3 deletions
@@ -15,6 +15,7 @@
from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
+    NoResamplingStrategyTypes
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
@@ -81,9 +82,11 @@ def __init__(
        output_directory: Optional[str] = None,
        delete_tmp_folder_after_terminate: bool = True,
        delete_output_folder_after_terminate: bool = True,
-        include_components: Optional[Dict[str, Any]] = None,
-        exclude_components: Optional[Dict[str, Any]] = None,
-        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        include_components: Optional[Dict] = None,
+        exclude_components: Optional[Dict] = None,
+        resampling_strategy: Union[CrossValTypes,
+                                   HoldoutValTypes,
+                                   NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        backend: Optional[Backend] = None,
        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
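
As a usage aid (not part of the diff), a minimal sketch of how the widened parameter might be passed to the tabular classification API; the seed value is illustrative:

    # Sketch: request the new no-resampling strategy when constructing the task.
    # Per the docstrings added in this commit, NoResamplingStrategyTypes fits on the
    # whole training set and is not compatible with HPO search, so it is meant for
    # single-pipeline fits rather than search().
    from autoPyTorch.api.tabular_classification import TabularClassificationTask
    from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes

    api = TabularClassificationTask(
        seed=42,
        resampling_strategy=NoResamplingStrategyTypes.no_resampling,
    )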

autoPyTorch/api/tabular_regression.py

Lines changed: 6 additions & 3 deletions
@@ -15,6 +15,7 @@
from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
+    NoResamplingStrategyTypes
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
@@ -82,9 +83,11 @@ def __init__(
        output_directory: Optional[str] = None,
        delete_tmp_folder_after_terminate: bool = True,
        delete_output_folder_after_terminate: bool = True,
-        include_components: Optional[Dict[str, Any]] = None,
-        exclude_components: Optional[Dict[str, Any]] = None,
-        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        include_components: Optional[Dict] = None,
+        exclude_components: Optional[Dict] = None,
+        resampling_strategy: Union[CrossValTypes,
+                                   HoldoutValTypes,
+                                   NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        backend: Optional[Backend] = None,
        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None

autoPyTorch/datasets/base_dataset.py

Lines changed: 22 additions & 7 deletions
@@ -21,7 +21,10 @@
    DEFAULT_RESAMPLING_PARAMETERS,
    HoldOutFunc,
    HoldOutFuncs,
-    HoldoutValTypes
+    HoldoutValTypes,
+    get_no_resampling_validators,
+    NoResamplingStrategyTypes,
+    NO_RESAMPLING_FN
)
from autoPyTorch.utils.common import FitRequirement

@@ -78,7 +81,9 @@ def __init__(
        dataset_name: Optional[str] = None,
        val_tensors: Optional[BaseDatasetInputType] = None,
        test_tensors: Optional[BaseDatasetInputType] = None,
-        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        resampling_strategy: Union[CrossValTypes,
+                                   HoldoutValTypes,
+                                   NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        shuffle: Optional[bool] = True,
        seed: Optional[int] = 42,
@@ -95,7 +100,7 @@ def __init__(
                validation data
            test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
                test data
-            resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
+            resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]),
                (default=HoldoutValTypes.holdout_validation):
                strategy to split the training data.
            resampling_strategy_args (Optional[Dict[str, Any]]): arguments
@@ -117,9 +122,16 @@ def __init__(
        if not hasattr(train_tensors[0], 'shape'):
            type_check(train_tensors, val_tensors)
        self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
+<<<<<<< HEAD
        self.cross_validators: Dict[str, CrossValFunc] = {}
        self.holdout_validators: Dict[str, HoldOutFunc] = {}
        self.random_state = np.random.RandomState(seed=seed)
+=======
+        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
+        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
+        self.no_resampling_validators: Dict[str, NO_RESAMPLING_FN] = {}
+        self.rng = np.random.RandomState(seed=seed)
+>>>>>>> Fix mypy and flake
        self.shuffle = shuffle
        self.resampling_strategy = resampling_strategy
        self.resampling_strategy_args = resampling_strategy_args
@@ -144,6 +156,8 @@ def __init__(
        # Make sure cross validation splits are created once
        self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
        self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
+        self.no_resampling_validators = get_no_resampling_validators(*NoResamplingStrategyTypes)
+
        self.splits = self.get_splits_from_resampling_strategy()

        # We also need to be able to transform the data, be it for pre-processing
@@ -211,7 +225,7 @@ def __len__(self) -> int:
    def _get_indices(self) -> np.ndarray:
        return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))

-    def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
+    def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], Optional[List[int]]]]:
        """
        Creates a set of splits based on a resampling strategy provided

@@ -242,6 +256,8 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
                    num_splits=cast(int, num_splits),
                )
            )
+        elif isinstance(self.resampling_strategy, NoResamplingStrategyTypes):
+            splits.append((self.no_resampling_validators[self.resampling_strategy.name](self._get_indices()), None))
        else:
            raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
        return splits
@@ -313,7 +329,7 @@ def create_holdout_val_split(
            self.random_state, val_share, self._get_indices(), **kwargs)
        return train, val

-    def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
+    def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
        """
        The above split methods employ the Subset to internally subsample the whole dataset.

@@ -327,8 +343,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
            Dataset: the reduced dataset to be used for testing
        """
        # Subset creates a dataset. Splits is a (train_indices, test_indices) tuple
-        return (TransformSubset(self, self.splits[split_id][0], train=True),
-                TransformSubset(self, self.splits[split_id][1], train=False))
+        return TransformSubset(self, self.splits[split_id][0], train=train)

    def replace_data(self, X_train: BaseDatasetInputType,
                     X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
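
For orientation (an illustrative sketch, not code from the diff), the no-resampling branch above produces a single split whose second element is None, which is why the return type and get_dataset_for_training now carry an Optional element and a train flag:

    # Sketch: shape of the splits produced for a NoResamplingStrategyTypes dataset.
    import numpy as np

    indices = np.arange(10)               # stand-in for self._get_indices()
    splits = [(indices.tolist(), None)]   # (train indices, no held-out indices)

    train_indices, opt_indices = splits[0]
    assert opt_indices is None            # callers must handle the Optional second element
    print(len(train_indices))             # 10 -> every sample is used for fitting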

autoPyTorch/datasets/image_dataset.py

Lines changed: 5 additions & 2 deletions
@@ -24,6 +24,7 @@
from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
+    NoResamplingStrategyTypes
)

IMAGE_DATASET_INPUT = Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]
@@ -39,7 +40,7 @@ class ImageDataset(BaseDataset):
            validation data
        test (Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]):
            testing data
-        resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
+        resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]),
            (default=HoldoutValTypes.holdout_validation):
            strategy to split the training data.
        resampling_strategy_args (Optional[Dict[str, Any]]): arguments
@@ -57,7 +58,9 @@ def __init__(self,
                 train: IMAGE_DATASET_INPUT,
                 val: Optional[IMAGE_DATASET_INPUT] = None,
                 test: Optional[IMAGE_DATASET_INPUT] = None,
-                resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+                resampling_strategy: Union[CrossValTypes,
+                                           HoldoutValTypes,
+                                           NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
                 resampling_strategy_args: Optional[Dict[str, Any]] = None,
                 shuffle: Optional[bool] = True,
                 seed: Optional[int] = 42,

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 71 additions & 2 deletions
@@ -32,6 +32,11 @@ def __call__(self, random_state: np.random.RandomState, val_share: float,
        ...


+class NO_RESAMPLING_FN(Protocol):
+    def __call__(self, indices: np.ndarray) -> np.ndarray:
+        ...
+
+
class CrossValTypes(IntEnum):
    """The type of cross validation

@@ -76,8 +81,14 @@ def is_stratified(self) -> bool:
        return getattr(self, self.name) in stratified


+class NoResamplingStrategyTypes(IntEnum):
+    no_resampling = 8
+    shuffle_no_resampling = 9
+
+
# TODO: replace it with another way
-RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]
+RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]
+

DEFAULT_RESAMPLING_PARAMETERS: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] = {
    HoldoutValTypes.holdout_validation: {
@@ -98,7 +109,13 @@ def is_stratified(self) -> bool:
    CrossValTypes.time_series_cross_validation: {
        'num_splits': 5,
    },
-}
+    NoResamplingStrategyTypes.no_resampling: {
+        'shuffle': False
+    },
+    NoResamplingStrategyTypes.shuffle_no_resampling: {
+        'shuffle': True
+    }
+}  # type: Dict[Union[HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes], Dict[str, Any]]


class HoldOutFuncs():
@@ -225,3 +242,55 @@ def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, Cros
            for cross_val_type in cross_val_types
        }
        return cross_validators
+
+
+def get_no_resampling_validators(*no_resampling: NoResamplingStrategyTypes) -> Dict[str, NO_RESAMPLING_FN]:
+    no_resampling_strategies = {}  # type: Dict[str, NO_RESAMPLING_FN]
+    for strategy in no_resampling:
+        no_resampling_fn = globals()[strategy.name]
+        no_resampling_strategies[strategy.name] = no_resampling_fn
+    return no_resampling_strategies
+
+
+def no_resampling(indices: np.ndarray) -> np.ndarray:
+    """
+    Returns the indices without performing
+    any operation on them. To be used for
+    fitting on the whole dataset.
+    This strategy is not compatible with
+    HPO search.
+    Args:
+        indices: array of indices
+
+    Returns:
+        np.ndarray: array of indices
+    """
+    return indices
+
+
+def shuffle_no_resampling(indices: np.ndarray, **kwargs: Any) -> np.ndarray:
+    """
+    Returns the indices after shuffling them.
+    To be used for fitting on the whole dataset.
+    This strategy is not compatible with HPO search.
+    Args:
+        indices: array of indices
+
+    Returns:
+        np.ndarray: shuffled array of indices
+    """
+    if 'random_state' in kwargs:
+        if isinstance(kwargs['random_state'], np.random.RandomState):
+            kwargs['random_state'].shuffle(indices)
+        elif isinstance(kwargs['random_state'], int):
+            np.random.seed(kwargs['random_state'])
+            np.random.shuffle(indices)
+        else:
+            raise ValueError("Illegal value for 'random_state' entered. "
+                             "Expected it to be {} or {} but got {}".format(int,
+                                                                            np.random.RandomState,
+                                                                            type(kwargs['random_state'])))
+    else:
+        np.random.shuffle(indices)
+
+    return indices
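
A small usage sketch of the two helpers added above (the behaviour follows directly from the code in this diff; the printed values are only examples):

    # Sketch: no_resampling is the identity on the index array, while
    # shuffle_no_resampling permutes it, optionally driven by an int seed or a
    # np.random.RandomState passed as 'random_state'.
    import numpy as np
    from autoPyTorch.datasets.resampling_strategy import (
        no_resampling,
        shuffle_no_resampling,
    )

    indices = np.arange(5)
    print(no_resampling(indices))                                 # [0 1 2 3 4], unchanged
    print(shuffle_no_resampling(indices.copy(), random_state=3))  # same values, permuted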

autoPyTorch/datasets/tabular_dataset.py

Lines changed: 5 additions & 2 deletions
@@ -21,6 +21,7 @@
from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
+    NoResamplingStrategyTypes
)


@@ -32,7 +33,7 @@ class TabularDataset(BaseDataset):
        Y (Union[np.ndarray, pd.Series]): training data targets.
        X_test (Optional[Union[np.ndarray, pd.DataFrame]]): input testing data.
        Y_test (Optional[Union[np.ndarray, pd.DataFrame]]): testing data targets
-        resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
+        resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]),
            (default=HoldoutValTypes.holdout_validation):
            strategy to split the training data.
        resampling_strategy_args (Optional[Dict[str, Any]]):
@@ -55,7 +56,9 @@ def __init__(self,
                 Y: Union[np.ndarray, pd.Series],
                 X_test: Optional[Union[np.ndarray, pd.DataFrame]] = None,
                 Y_test: Optional[Union[np.ndarray, pd.DataFrame]] = None,
-                resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+                resampling_strategy: Union[CrossValTypes,
+                                           HoldoutValTypes,
+                                           NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
                 resampling_strategy_args: Optional[Dict[str, Any]] = None,
                 shuffle: Optional[bool] = True,
                 seed: Optional[int] = 42,

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 6 additions & 3 deletions
@@ -663,9 +663,9 @@ def _loss(self, y_true: np.ndarray, y_hat: np.ndarray) -> Dict[str, float]:
            y_true, y_hat, self.task_type, metrics)

    def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
-                  opt_pred: np.ndarray, valid_pred: Optional[np.ndarray],
-                  test_pred: Optional[np.ndarray], additional_run_info: Optional[Dict],
-                  file_output: bool, status: StatusType
+                  valid_pred: Optional[np.ndarray], test_pred: Optional[np.ndarray],
+                  additional_run_info: Optional[Dict], file_output: bool, status: StatusType,
+                  opt_pred: Optional[np.ndarray],
                  ) -> Optional[Tuple[float, float, int, Dict]]:
        """This function does everything necessary after the fitting is done:

@@ -707,6 +707,9 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
                Additional run information, like train/test loss
        """

+        assert opt_pred is not None, "Cases where 'opt_pred' is None should be handled " \
+                                     "specifically with special child classes"
+
        self.duration = time.time() - self.starttime

        if file_output:
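
The reordered signature leaves opt_pred Optional in the annotation but still required by this base implementation; a minimal sketch of that contract (the helper name below is illustrative, only the assertion message comes from the diff):

    from typing import Optional

    import numpy as np

    def _require_opt_pred(opt_pred: Optional[np.ndarray]) -> np.ndarray:
        # Mirrors the guard added to finish_up: evaluators that legitimately produce
        # no out-of-fold predictions are expected to override this in a child class.
        assert opt_pred is not None, ("Cases where 'opt_pred' is None should be handled "
                                      "specifically with special child classes")
        return opt_pred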
