
Commit 6635b3b

reflected the comments from ravin
nabenabe0928 committed Mar 4, 2021
1 parent 2cbc1ce commit 6635b3b
Showing 2 changed files with 29 additions and 20 deletions.
17 changes: 8 additions & 9 deletions autoPyTorch/datasets/base_dataset.py
@@ -23,7 +23,7 @@
 )
 from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix

-BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
+BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]


 def check_valid_data(data: Any) -> None:
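
For readers skimming the rename: the alias definition above is verbatim from the diff, and the following minimal sketch (my illustration, not part of the commit) shows the two kinds of values it admits; the example values are assumptions.

from typing import Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset

BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]

# Either form is a valid dataset input under this alias:
as_tuple: BaseDatasetInputType = (np.zeros((4, 3)), np.zeros(4))
as_dataset: BaseDatasetInputType = TensorDataset(torch.zeros(4, 3), torch.zeros(4))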
@@ -32,10 +32,9 @@ def check_valid_data(data: Any) -> None:
             'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')


-def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
-    """To avoid unexpected behavior, we use loops over indices."""
-    for i in range(len(train_tensors)):
-        check_valid_data(train_tensors[i])
+def type_check(train_tensors: BaseDatasetInputType, val_tensors: Optional[BaseDatasetInputType] = None) -> None:
+    for train_tensor in train_tensors:
+        check_valid_data(train_tensor)
     if val_tensors is not None:
         for i in range(len(val_tensors)):
             check_valid_data(val_tensors[i])
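
To see the refactor in action, here is a self-contained sketch I assembled from the hunk above; the body of check_valid_data is paraphrased from the error message it raises, so treat it as an assumption rather than the exact implementation.

from typing import Any, Optional, Tuple

import numpy as np

def check_valid_data(data: Any) -> None:
    # Paraphrased guard; the real implementation lives in base_dataset.py.
    if not all(hasattr(data, attr) for attr in ('__getitem__', '__len__')):
        raise ValueError(
            'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')

def type_check(train_tensors: Tuple[np.ndarray, np.ndarray],
               val_tensors: Optional[Tuple[np.ndarray, np.ndarray]] = None) -> None:
    for train_tensor in train_tensors:  # new style: iterate elements, not indices
        check_valid_data(train_tensor)
    if val_tensors is not None:
        for i in range(len(val_tensors)):
            check_valid_data(val_tensors[i])

X, y = np.arange(12).reshape(4, 3), np.arange(4)
type_check((X, y))  # ndarrays pass: they expose __getitem__ and __len__

try:
    type_check((X, object()))  # a bare object lacks both attributes
except ValueError as err:
    print(err)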
@@ -63,10 +62,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
 class BaseDataset(Dataset, metaclass=ABCMeta):
     def __init__(
         self,
-        train_tensors: BaseDatasetType,
+        train_tensors: BaseDatasetInputType,
         dataset_name: Optional[str] = None,
-        val_tensors: Optional[BaseDatasetType] = None,
-        test_tensors: Optional[BaseDatasetType] = None,
+        val_tensors: Optional[BaseDatasetInputType] = None,
+        test_tensors: Optional[BaseDatasetInputType] = None,
         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
@@ -313,7 +312,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         return (TransformSubset(self, self.splits[split_id][0], train=True),
                 TransformSubset(self, self.splits[split_id][1], train=False))

-    def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
+    def replace_data(self, X_train: BaseDatasetInputType, X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
         """
         To speed up the training of small dataset, early pre-processing of the data
         can be made on the fly by the pipeline.
32 changes: 21 additions & 11 deletions autoPyTorch/datasets/resampling_strategy.py
@@ -150,9 +150,12 @@ def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any)
     """
     Standard k fold cross validation.

-    :param indices: array of indices to be split
-    :param num_splits: number of cross validation splits
-    :return: list of tuples of training and validation indices
+    Args:
+        indices (np.ndarray): array of indices to be split
+        num_splits (int): number of cross validation splits
+
+    Returns:
+        splits (List[Tuple[List, List]]): list of tuples of training and validation indices
     """
     cv = KFold(n_splits=num_splits)
     splits = list(cv.split(indices))
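
As a quick sanity check on the new docstring, this snippet of mine (assuming scikit-learn is installed) shows that KFold yields the advertised list of (train, validation) index tuples:

import numpy as np
from sklearn.model_selection import KFold

indices = np.arange(6)
cv = KFold(n_splits=3)
splits = list(cv.split(indices))  # same call as in k_fold_cross_validation
for train, val in splits:
    print(train, val)
# [2 3 4 5] [0 1]
# [0 1 4 5] [2 3]
# [0 1 2 3] [4 5]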
@@ -163,14 +166,21 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
         -> List[Tuple[np.ndarray, np.ndarray]]:
     """
     Returns train and validation indices respecting the temporal ordering of the data.

-    Dummy example: [0, 1, 2, 3] with 3 folds yields
-        [0] [1]
-        [0, 1] [2]
-        [0, 1, 2] [3]
-    :param indices: array of indices to be split
-    :param num_splits: number of cross validation splits
-    :return: list of tuples of training and validation indices
+    Args:
+        indices (np.ndarray): array of indices to be split
+        num_splits (int): number of cross validation splits
+
+    Returns:
+        splits (List[Tuple[List, List]]): list of tuples of training and validation indices
+
+    Examples:
+        >>> indices = np.array([0, 1, 2, 3])
+        >>> CrossValFuncs.time_series_cross_validation(3, indices)
+        [([0], [1]),
+         ([0, 1], [2]),
+         ([0, 1, 2], [3])]
+
     """
     cv = TimeSeriesSplit(n_splits=num_splits)
     splits = list(cv.split(indices))
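
The added doctest can be verified directly with scikit-learn; this standalone snippet (mine, not in the diff) mirrors the function body on the docstring's example input:

import numpy as np
from sklearn.model_selection import TimeSeriesSplit

indices = np.array([0, 1, 2, 3])
cv = TimeSeriesSplit(n_splits=3)  # same construction as in the function body
for train, val in cv.split(indices):
    print(list(train), list(val))
# [0] [1]
# [0, 1] [2]
# [0, 1, 2] [3]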
