-
Notifications
You must be signed in to change notification settings - Fork 299
Refactoring base dataset splitting functions #106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f204670
c22ec53
bfd1a83
4be32b3
5284a84
dbe7ca6
594bfaa
a8981a9
e27b46f
6b28b42
96e2bb0
0629efe
d20c1b5
244a23e
988a3a2
70d7d60
1c6dc23
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
import os | ||
import uuid | ||
from abc import ABCMeta | ||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast | ||
|
||
|
@@ -13,18 +15,17 @@ | |
|
||
from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES | ||
from autoPyTorch.datasets.resampling_strategy import ( | ||
CROSS_VAL_FN, | ||
CrossValFunc, | ||
CrossValFuncs, | ||
nabenabe0928 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
CrossValTypes, | ||
DEFAULT_RESAMPLING_PARAMETERS, | ||
HOLDOUT_FN, | ||
HoldoutValTypes, | ||
get_cross_validators, | ||
get_holdout_validators, | ||
is_stratified, | ||
HoldOutFunc, | ||
HoldOutFuncs, | ||
nabenabe0928 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
HoldoutValTypes | ||
) | ||
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix | ||
from autoPyTorch.utils.common import FitRequirement | ||
|
||
BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset] | ||
BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset] | ||
|
||
|
||
def check_valid_data(data: Any) -> None: | ||
|
@@ -33,7 +34,8 @@ def check_valid_data(data: Any) -> None: | |
'The specified Data for Dataset must have both __getitem__ and __len__ attribute.') | ||
|
||
|
||
def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None: | ||
def type_check(train_tensors: BaseDatasetInputType, | ||
val_tensors: Optional[BaseDatasetInputType] = None) -> None: | ||
"""To avoid unexpected behavior, we use loops over indices.""" | ||
for i in range(len(train_tensors)): | ||
check_valid_data(train_tensors[i]) | ||
|
@@ -49,8 +51,8 @@ class TransformSubset(Subset): | |
we require different transformation for each data point. | ||
This class helps to take the subset of the dataset | ||
with either training or validation transformation. | ||
|
||
We achieve so by adding a train flag to the pytorch subset | ||
The TransformSubset allows to add train flags | ||
while indexing the main dataset towards this goal. | ||
|
||
Attributes: | ||
dataset (BaseDataset/Dataset): Dataset to sample the subset | ||
|
@@ -71,10 +73,10 @@ def __getitem__(self, idx: int) -> np.ndarray: | |
class BaseDataset(Dataset, metaclass=ABCMeta): | ||
def __init__( | ||
self, | ||
train_tensors: BaseDatasetType, | ||
train_tensors: BaseDatasetInputType, | ||
dataset_name: Optional[str] = None, | ||
val_tensors: Optional[BaseDatasetType] = None, | ||
test_tensors: Optional[BaseDatasetType] = None, | ||
val_tensors: Optional[BaseDatasetInputType] = None, | ||
test_tensors: Optional[BaseDatasetInputType] = None, | ||
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, | ||
resampling_strategy_args: Optional[Dict[str, Any]] = None, | ||
shuffle: Optional[bool] = True, | ||
|
@@ -106,14 +108,16 @@ def __init__( | |
val_transforms (Optional[torchvision.transforms.Compose]): | ||
Additional Transforms to be applied to the validation/test data | ||
""" | ||
self.dataset_name = dataset_name if dataset_name is not None \ | ||
else hash_array_or_matrix(train_tensors[0]) | ||
self.dataset_name = dataset_name | ||
|
||
if self.dataset_name is None: | ||
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then if you think that it should not be required, then maybe do it as:
What do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, actually I am also thinking about it this way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, wait. But when I think about the case where we would like to use totally new datasets which do not have any name, probably we would like to choose our own name. In this sense, it is better to get back to |
||
|
||
if not hasattr(train_tensors[0], 'shape'): | ||
type_check(train_tensors, val_tensors) | ||
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors | ||
self.cross_validators: Dict[str, CROSS_VAL_FN] = {} | ||
self.holdout_validators: Dict[str, HOLDOUT_FN] = {} | ||
ravinkohli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.cross_validators: Dict[str, CrossValFunc] = {} | ||
self.holdout_validators: Dict[str, HoldOutFunc] = {} | ||
self.rng = np.random.RandomState(seed=seed) | ||
self.shuffle = shuffle | ||
self.resampling_strategy = resampling_strategy | ||
|
@@ -134,8 +138,8 @@ def __init__( | |
self.is_small_preprocess = True | ||
|
||
# Make sure cross validation splits are created once | ||
self.cross_validators = get_cross_validators(*CrossValTypes) | ||
self.holdout_validators = get_holdout_validators(*HoldoutValTypes) | ||
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes) | ||
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes) | ||
self.splits = self.get_splits_from_resampling_strategy() | ||
|
||
# We also need to be able to transform the data, be it for pre-processing | ||
|
@@ -263,7 +267,7 @@ def create_cross_val_splits( | |
if not isinstance(cross_val_type, CrossValTypes): | ||
raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.') | ||
kwargs = {} | ||
if is_stratified(cross_val_type): | ||
if cross_val_type.is_stratified(): | ||
# we need additional information about the data for stratification | ||
kwargs["stratify"] = self.train_tensors[-1] | ||
splits = self.cross_validators[cross_val_type.name]( | ||
|
@@ -298,7 +302,7 @@ def create_holdout_val_split( | |
if not isinstance(holdout_val_type, HoldoutValTypes): | ||
raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.') | ||
kwargs = {} | ||
if is_stratified(holdout_val_type): | ||
if holdout_val_type.is_stratified(): | ||
# we need additional information about the data for stratification | ||
kwargs["stratify"] = self.train_tensors[-1] | ||
train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs) | ||
|
@@ -321,7 +325,8 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]: | |
return (TransformSubset(self, self.splits[split_id][0], train=True), | ||
TransformSubset(self, self.splits[split_id][1], train=False)) | ||
|
||
def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset': | ||
def replace_data(self, X_train: BaseDatasetInputType, | ||
X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset': | ||
""" | ||
To speed up the training of small dataset, early pre-processing of the data | ||
can be made on the fly by the pipeline. | ||
|
Uh oh!
There was an error while loading. Please reload this page.