-
Notifications
You must be signed in to change notification settings - Fork 299
[ADD] Test evaluator #368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ravinkohli
merged 21 commits into
automl:development
from
ravinkohli:add_test_evaluator
Jan 25, 2022
Merged
[ADD] Test evaluator #368
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
2aaf453
add test evaluator
ravinkohli 8ad387f
add no resampling and other changes for test evaluator
ravinkohli e085eeb
finalise changes for test_evaluator, TODO: tests
ravinkohli 4bc6c01
add tests for new functionality
ravinkohli 40d67e5
fix flake and mypy
ravinkohli b5f8992
add documentation for the evaluator
ravinkohli e898496
add NoResampling to fit_pipeline
ravinkohli 01c4a9a
raise error when trying to construct ensemble with noresampling
ravinkohli 6620f78
fix tests
ravinkohli d6874e1
reduce fit_pipeline accuracy check
ravinkohli c7a723e
Apply suggestions from code review
ravinkohli 07980cd
address comments from shuhei
ravinkohli 06cdc22
fix bug in base data loader
ravinkohli be73604
fix bug in data loader for val set
ravinkohli ff56311
fix bugs introduced in suggestions
ravinkohli 3836ddf
fix flake
ravinkohli 657c152
fix bug in test preprocessing
ravinkohli 70a17a6
fix bug in test data loader
ravinkohli 11c161c
merge tests for evaluators and change listcomp in get_best_epoch
ravinkohli 8055406
rename resampling strategies
ravinkohli c017fac
add test for get dataset
ravinkohli File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,11 @@ | |
DEFAULT_RESAMPLING_PARAMETERS, | ||
HoldOutFunc, | ||
HoldOutFuncs, | ||
HoldoutValTypes | ||
HoldoutValTypes, | ||
NoResamplingFunc, | ||
NoResamplingFuncs, | ||
NoResamplingStrategyTypes, | ||
ResamplingStrategies | ||
) | ||
from autoPyTorch.utils.common import FitRequirement | ||
|
||
|
@@ -78,7 +82,7 @@ def __init__( | |
dataset_name: Optional[str] = None, | ||
val_tensors: Optional[BaseDatasetInputType] = None, | ||
test_tensors: Optional[BaseDatasetInputType] = None, | ||
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, | ||
resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation, | ||
resampling_strategy_args: Optional[Dict[str, Any]] = None, | ||
shuffle: Optional[bool] = True, | ||
seed: Optional[int] = 42, | ||
|
@@ -95,8 +99,7 @@ def __init__( | |
validation data | ||
test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute): | ||
test data | ||
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]), | ||
(default=HoldoutValTypes.holdout_validation): | ||
resampling_strategy (RESAMPLING_STRATEGIES: default=HoldoutValTypes.holdout_validation): | ||
strategy to split the training data. | ||
resampling_strategy_args (Optional[Dict[str, Any]]): arguments | ||
required for the chosen resampling strategy. If None, uses | ||
|
@@ -109,16 +112,18 @@ def __init__( | |
val_transforms (Optional[torchvision.transforms.Compose]): | ||
Additional Transforms to be applied to the validation/test data | ||
""" | ||
self.dataset_name = dataset_name | ||
|
||
if self.dataset_name is None: | ||
if dataset_name is None: | ||
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) | ||
else: | ||
self.dataset_name = dataset_name | ||
|
||
if not hasattr(train_tensors[0], 'shape'): | ||
type_check(train_tensors, val_tensors) | ||
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors | ||
self.cross_validators: Dict[str, CrossValFunc] = {} | ||
self.holdout_validators: Dict[str, HoldOutFunc] = {} | ||
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {} | ||
self.random_state = np.random.RandomState(seed=seed) | ||
self.shuffle = shuffle | ||
self.resampling_strategy = resampling_strategy | ||
|
@@ -143,6 +148,8 @@ def __init__( | |
# Make sure cross validation splits are created once | ||
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes) | ||
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes) | ||
self.no_resampling_validators = NoResamplingFuncs.get_no_resampling_validators(*NoResamplingStrategyTypes) | ||
|
||
self.splits = self.get_splits_from_resampling_strategy() | ||
|
||
# We also need to be able to transform the data, be it for pre-processing | ||
|
@@ -210,7 +217,7 @@ def __len__(self) -> int: | |
def _get_indices(self) -> np.ndarray: | ||
return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self)) | ||
|
||
def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]: | ||
def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], Optional[List[int]]]]: | ||
""" | ||
Creates a set of splits based on a resampling strategy provided | ||
|
||
|
@@ -241,6 +248,9 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int] | |
num_splits=cast(int, num_splits), | ||
) | ||
) | ||
elif isinstance(self.resampling_strategy, NoResamplingStrategyTypes): | ||
splits.append((self.no_resampling_validators[self.resampling_strategy.name](self.random_state, | ||
self._get_indices()), None)) | ||
else: | ||
raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}") | ||
return splits | ||
|
@@ -312,22 +322,29 @@ def create_holdout_val_split( | |
self.random_state, val_share, self._get_indices(), **kwargs) | ||
return train, val | ||
|
||
def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]: | ||
def get_dataset(self, split_id: int, train: bool) -> Dataset: | ||
""" | ||
The above split methods employ the Subset to internally subsample the whole dataset. | ||
|
||
During training, we need access to one of those splits. This is a handy function | ||
to provide training data to fit a pipeline | ||
|
||
Args: | ||
split (int): The desired subset of the dataset to split and use | ||
split_id (int): which split id to get from the splits | ||
train (bool): whether the dataset is required for training or evaluating. | ||
|
||
Returns: | ||
Dataset: the reduced dataset to be used for testing | ||
""" | ||
# Subset creates a dataset. Splits is a (train_indices, test_indices) tuple | ||
return (TransformSubset(self, self.splits[split_id][0], train=True), | ||
TransformSubset(self, self.splits[split_id][1], train=False)) | ||
if split_id >= len(self.splits): # old version: split_id > len(self.splits) | ||
raise IndexError(f"self.splits index out of range, got split_id={split_id}" | ||
f" (>= num_splits={len(self.splits)})") | ||
indices = self.splits[split_id][int(not train)] # 0: for training, 1: for evaluation | ||
if indices is None: | ||
raise ValueError("Specified fold (or subset) does not exist") | ||
Comment on lines +344 to +345
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reviewer: Could you cover this line by a test?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Author: sure |
||
|
||
return TransformSubset(self, indices, train=train) | ||
|
||
def replace_data(self, X_train: BaseDatasetInputType, | ||
X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset': | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.