Commit fae72a4

Refactoring base dataset splitting functions (#106)
* [Fork from #105] Made CrossValFuncs and HoldOutFuncs classes to group the split functions
* Modified time_series_dataset.py to be compatible with resampling_strategy.py
* [fix] Went back to the renamed version of CROSS_VAL_FN from the temporary SplitFunc typing
* Fixed flake8 issues in three files
* Fixed the flake8 issues
* [refactor] Addressed Francisco's comments
* [refactor] Addressed Francisco's comments
* [refactor] Addressed the docstring issue in the TransformSubset class
* [fix] Addressed flake8 issues
* [fix] Fixed a flake8 issue
* [fix] Fixed mypy issues raised by the GitHub check
* [fix] Fixed a mypy issue
* [fix] Fixed a contradiction in holdout_stratified_validation: stratified splitting requires shuffling by default, and the GitHub check raised an error, so this is now handled
* [fix] Addressed Francisco's review
* [fix] Fixed a mypy issue in tabular_dataset.py
* [fix] Addressed Francisco's comment about self.dataset_name: since we want to allow a dataset that does not have a name, self.dataset_name is back to Optional[str]
* [fix] Fixed mypy issues
1 parent a4e08e2 commit fae72a4
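At a glance, the refactor moves the split helpers from module-level functions into the CrossValFuncs and HoldOutFuncs classes and turns is_stratified into a method on the validation-type enums. The snippet below is only an illustration of the new call sites as they appear in the diffs that follow (it assumes autoPyTorch at this commit is importable); it is not part of the change itself.

```python
# Illustration of the renamed/regrouped API used by base_dataset.py below.
from autoPyTorch.datasets.resampling_strategy import (
    CrossValFuncs,
    CrossValTypes,
    HoldOutFuncs,
    HoldoutValTypes,
)

# Before: get_cross_validators(*CrossValTypes) / get_holdout_validators(*HoldoutValTypes)
cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)

# Before: is_stratified(holdout_val_type)
print(HoldoutValTypes.holdout_validation.is_stratified())
```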

File tree: 5 files changed (+240, -180 lines)


autoPyTorch/api/base_task.py

Lines changed: 11 additions & 11 deletions
@@ -10,7 +10,6 @@
 import time
 import typing
 import unittest.mock
-import uuid
 import warnings
 from abc import abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Union, cast
@@ -782,13 +781,15 @@ def _search(
                              ":{}".format(self.task_type, dataset.task_type))

         # Initialise information needed for the experiment
-        experiment_task_name = 'runSearch'
+        experiment_task_name: str = 'runSearch'
         dataset_requirements = get_dataset_requirements(
             info=self._get_required_dataset_properties(dataset))
         self._dataset_requirements = dataset_requirements
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._stopwatch.start_task(experiment_task_name)
         self.dataset_name = dataset.dataset_name
+        assert self.dataset_name is not None
+
         if self._logger is None:
             self._logger = self._get_logger(self.dataset_name)
         self._all_supported_metrics = all_supported_metrics
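The added assert (and the str(dataset.dataset_name) casts in the hunks below) exist to satisfy mypy now that dataset_name is typed Optional[str]. A minimal, standalone illustration of the narrowing pattern (not repository code):

```python
from typing import Optional


def logger_name(dataset_name: Optional[str]) -> str:
    # mypy rejects using Optional[str] where str is expected; the assert narrows
    # the type, exactly like `assert self.dataset_name is not None` in _search above.
    assert dataset_name is not None
    return dataset_name
```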
@@ -897,7 +898,7 @@ def _search(
             start_time=time.time(),
             time_left_for_ensembles=time_left_for_ensembles,
             backend=copy.deepcopy(self._backend),
-            dataset_name=dataset.dataset_name,
+            dataset_name=str(dataset.dataset_name),
             output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type],
             task_type=STRING_TO_TASK_TYPES[self.task_type],
             metrics=[self._metric],
@@ -916,7 +917,7 @@ def _search(
         self._stopwatch.stop_task(ensemble_task_name)

         # ==> Run SMAC
-        smac_task_name = 'runSMAC'
+        smac_task_name: str = 'runSMAC'
         self._stopwatch.start_task(smac_task_name)
         elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
         time_left_for_smac = max(0, total_walltime_limit - elapsed_time)
@@ -928,7 +929,7 @@ def _search(

         _proc_smac = AutoMLSMBO(
             config_space=self.search_space,
-            dataset_name=dataset.dataset_name,
+            dataset_name=str(dataset.dataset_name),
             backend=self._backend,
             total_walltime_limit=total_walltime_limit,
             func_eval_time_limit_secs=func_eval_time_limit_secs,
@@ -1035,11 +1036,11 @@ def refit(
         Returns:
             self
         """
-        if self.dataset_name is None:
-            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
+
+        self.dataset_name = dataset.dataset_name

         if self._logger is None:
-            self._logger = self._get_logger(self.dataset_name)
+            self._logger = self._get_logger(str(self.dataset_name))

         dataset_requirements = get_dataset_requirements(
             info=self._get_required_dataset_properties(dataset))
@@ -1105,11 +1106,10 @@ def fit(self,
         Returns:
             (BasePipeline): fitted pipeline
         """
-        if self.dataset_name is None:
-            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
+        self.dataset_name = dataset.dataset_name

         if self._logger is None:
-            self._logger = self._get_logger(self.dataset_name)
+            self._logger = self._get_logger(str(self.dataset_name))

         # get dataset properties
         dataset_requirements = get_dataset_requirements(

autoPyTorch/datasets/base_dataset.py

Lines changed: 28 additions & 23 deletions
@@ -1,3 +1,5 @@
+import os
+import uuid
 from abc import ABCMeta
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

@@ -13,18 +15,17 @@

 from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
 from autoPyTorch.datasets.resampling_strategy import (
-    CROSS_VAL_FN,
+    CrossValFunc,
+    CrossValFuncs,
     CrossValTypes,
     DEFAULT_RESAMPLING_PARAMETERS,
-    HOLDOUT_FN,
-    HoldoutValTypes,
-    get_cross_validators,
-    get_holdout_validators,
-    is_stratified,
+    HoldOutFunc,
+    HoldOutFuncs,
+    HoldoutValTypes
 )
-from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
+from autoPyTorch.utils.common import FitRequirement

-BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
+BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]


 def check_valid_data(data: Any) -> None:
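The imports above only show the new names; resampling_strategy.py itself is not part of this page. A hypothetical sketch of the grouping they imply: callables typed as HoldOutFunc (and CrossValFunc) collected as static methods on HoldOutFuncs (and CrossValFuncs), looked up by enum-member name. Everything except the imported names and the get_*_validators entry points is an assumption here.

```python
from typing import Any, Callable, Dict, Tuple

import numpy as np
from sklearn.model_selection import train_test_split

# Assumed signature: split indices into (train, val) given a validation share.
HoldOutFunc = Callable[..., Tuple[np.ndarray, np.ndarray]]


class HoldOutFuncs:
    @staticmethod
    def holdout_validation(val_share: float, indices: np.ndarray,
                           **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
        train, val = train_test_split(indices, test_size=val_share, shuffle=False)
        return train, val

    @classmethod
    def get_holdout_validators(cls, *holdout_val_types: Any) -> Dict[str, HoldOutFunc]:
        # one entry per requested enum member, keyed by its name
        return {val_type.name: getattr(cls, val_type.name) for val_type in holdout_val_types}
```

CrossValFuncs would follow the same pattern for the fold-based splitters, which is why base_dataset.py can index both dictionaries by cross_val_type.name and holdout_val_type.name further down.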
@@ -33,7 +34,8 @@ def check_valid_data(data: Any) -> None:
             'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')


-def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
+def type_check(train_tensors: BaseDatasetInputType,
+               val_tensors: Optional[BaseDatasetInputType] = None) -> None:
     """To avoid unexpected behavior, we use loops over indices."""
     for i in range(len(train_tensors)):
         check_valid_data(train_tensors[i])
@@ -49,8 +51,8 @@ class TransformSubset(Subset):
     we require different transformation for each data point.
     This class helps to take the subset of the dataset
     with either training or validation transformation.
-
-    We achieve so by adding a train flag to the pytorch subset
+    The TransformSubset allows to add train flags
+    while indexing the main dataset towards this goal.

     Attributes:
         dataset (BaseDataset/Dataset): Dataset to sample the subset
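The docstring above describes the mechanism only in words. A rough sketch of the idea (an assumed shape, not the actual TransformSubset implementation): the subset stores a train flag and forwards it to the parent dataset on every lookup, so the parent can apply either the training or the validation transform to that single data point.

```python
from typing import Any, Sequence

from torch.utils.data import Dataset, Subset


class TrainFlaggedSubset(Subset):  # hypothetical stand-in for TransformSubset
    def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None:
        super().__init__(dataset, indices)
        self.train = train

    def __getitem__(self, idx: int) -> Any:
        # assumes the parent dataset's __getitem__ accepts a train flag and
        # picks the matching transform for this single data point
        return self.dataset.__getitem__(self.indices[idx], self.train)
```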
@@ -71,10 +73,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
 class BaseDataset(Dataset, metaclass=ABCMeta):
     def __init__(
         self,
-        train_tensors: BaseDatasetType,
+        train_tensors: BaseDatasetInputType,
         dataset_name: Optional[str] = None,
-        val_tensors: Optional[BaseDatasetType] = None,
-        test_tensors: Optional[BaseDatasetType] = None,
+        val_tensors: Optional[BaseDatasetInputType] = None,
+        test_tensors: Optional[BaseDatasetInputType] = None,
         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
@@ -106,14 +108,16 @@ def __init__(
             val_transforms (Optional[torchvision.transforms.Compose]):
                 Additional Transforms to be applied to the validation/test data
         """
-        self.dataset_name = dataset_name if dataset_name is not None \
-            else hash_array_or_matrix(train_tensors[0])
+        self.dataset_name = dataset_name
+
+        if self.dataset_name is None:
+            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
-        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
+        self.cross_validators: Dict[str, CrossValFunc] = {}
+        self.holdout_validators: Dict[str, HoldOutFunc] = {}
         self.rng = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
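The new fallback name above can be reproduced standalone; uuid1 encodes the timestamp and host, and passing the process id as clock_seq helps distinguish names generated at the same instant by different processes on one machine (a behavioural note, not something stated in the diff).

```python
import os
import uuid

dataset_name = None
if dataset_name is None:
    # same expression as in BaseDataset.__init__ above
    dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
print(dataset_name)
```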
@@ -134,8 +138,8 @@ def __init__(
             self.is_small_preprocess = True

         # Make sure cross validation splits are created once
-        self.cross_validators = get_cross_validators(*CrossValTypes)
-        self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
+        self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
+        self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
         self.splits = self.get_splits_from_resampling_strategy()

         # We also need to be able to transform the data, be it for pre-processing
@@ -263,7 +267,7 @@ def create_cross_val_splits(
         if not isinstance(cross_val_type, CrossValTypes):
             raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
         kwargs = {}
-        if is_stratified(cross_val_type):
+        if cross_val_type.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
         splits = self.cross_validators[cross_val_type.name](
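The stratify entry added to kwargs above carries the labels (the last element of train_tensors). A sketch with scikit-learn, not the repository's splitter, of what a stratified cross-validation function ultimately produces: one (train_indices, val_indices) pair per fold with the class ratio preserved.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

indices = np.arange(100)
labels = np.random.RandomState(0).randint(0, 2, size=100)  # stands in for train_tensors[-1]

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
splits = [(indices[train], indices[val]) for train, val in cv.split(indices, labels)]
print(len(splits), [len(val) for _, val in splits])
```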
@@ -298,7 +302,7 @@ def create_holdout_val_split(
         if not isinstance(holdout_val_type, HoldoutValTypes):
             raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
         kwargs = {}
-        if is_stratified(holdout_val_type):
+        if holdout_val_type.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
         train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
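The commit message mentions fixing a contradiction in holdout_stratified_validation: stratified splitting only works when the data is shuffled. A minimal scikit-learn illustration of that constraint (again an external sketch, not the helper the dictionary lookup above resolves to):

```python
import numpy as np
from sklearn.model_selection import train_test_split

indices = np.arange(100)
labels = np.random.RandomState(1).randint(0, 2, size=100)  # plays the role of kwargs["stratify"]

# shuffle=False together with stratify raises a ValueError in scikit-learn,
# which is the contradiction the commit message refers to.
train_idx, val_idx = train_test_split(
    indices, test_size=0.33, shuffle=True, stratify=labels, random_state=1)
```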
@@ -321,7 +325,8 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         return (TransformSubset(self, self.splits[split_id][0], train=True),
                 TransformSubset(self, self.splits[split_id][1], train=False))

-    def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
+    def replace_data(self, X_train: BaseDatasetInputType,
+                     X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
         """
         To speed up the training of small dataset, early pre-processing of the data
         can be made on the fly by the pipeline.
