Commit e232f61

Additional metrics during train (#194)
* Added additional metrics to the fit dictionary
* Added them to the tests as well
* Fixed mypy and flake8 errors after the rebase
* Added a random state to Mixup and Cutout
* Adapted no-resampling to the new code
* Fixed a bug in setup.py
1 parent e9b9458 commit e232f61

File tree

6 files changed: +24 -49 lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 2 additions & 34 deletions
@@ -82,13 +82,7 @@ def __init__(
         dataset_name: Optional[str] = None,
         val_tensors: Optional[BaseDatasetInputType] = None,
         test_tensors: Optional[BaseDatasetInputType] = None,
-<<<<<<< HEAD
         resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
-=======
-        resampling_strategy: Union[CrossValTypes,
-                                   HoldoutValTypes,
-                                   NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
         seed: Optional[int] = 42,
@@ -105,12 +99,7 @@ def __init__(
                 validation data
             test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
                 test data
-<<<<<<< HEAD
             resampling_strategy (RESAMPLING_STRATEGIES: default=HoldoutValTypes.holdout_validation):
-=======
-            resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]),
-                (default=HoldoutValTypes.holdout_validation):
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
                 strategy to split the training data.
             resampling_strategy_args (Optional[Dict[str, Any]]): arguments
                 required for the chosen resampling strategy. If None, uses
@@ -132,17 +121,11 @@ def __init__(
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-<<<<<<< HEAD
         self.cross_validators: Dict[str, CrossValFunc] = {}
         self.holdout_validators: Dict[str, HoldOutFunc] = {}
         self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
         self.random_state = np.random.RandomState(seed=seed)
-=======
-        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
-        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
-        self.no_resampling_validators: Dict[str, NO_RESAMPLING_FN] = {}
-        self.rng = np.random.RandomState(seed=seed)
->>>>>>> Fix mypy and flake
+        self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
         self.resampling_strategy_args = resampling_strategy_args
@@ -167,11 +150,8 @@ def __init__(
         # Make sure cross validation splits are created once
         self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
         self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
-<<<<<<< HEAD
+
         self.no_resampling_validators = NoResamplingFuncs.get_no_resampling_validators(*NoResamplingStrategyTypes)
-=======
-        self.no_resampling_validators = get_no_resampling_validators(*NoResamplingStrategyTypes)
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics

         self.splits = self.get_splits_from_resampling_strategy()

@@ -272,12 +252,8 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], Optional[
                 )
             )
         elif isinstance(self.resampling_strategy, NoResamplingStrategyTypes):
-<<<<<<< HEAD
             splits.append((self.no_resampling_validators[self.resampling_strategy.name](self.random_state,
                                                                                         self._get_indices()), None))
-=======
-            splits.append((self.no_resampling_validators[self.resampling_strategy.name](self._get_indices()), None))
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
         else:
             raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
         return splits
@@ -349,11 +325,7 @@ def create_holdout_val_split(
             self.random_state, val_share, self._get_indices(), **kwargs)
         return train, val

-<<<<<<< HEAD
     def get_dataset(self, split_id: int, train: bool) -> Dataset:
-=======
-    def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
         """
         The above split methods employ the Subset to internally subsample the whole dataset.

@@ -368,7 +340,6 @@ def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
             Dataset: the reduced dataset to be used for testing
         """
         # Subset creates a dataset. Splits is a (train_indices, test_indices) tuple
-<<<<<<< HEAD
         if split_id >= len(self.splits):  # old version: split_id > len(self.splits)
             raise IndexError(f"self.splits index out of range, got split_id={split_id}"
                              f" (>= num_splits={len(self.splits)})")
@@ -377,9 +348,6 @@ def get_dataset_for_training(self, split_id: int, train: bool) -> Dataset:
             raise ValueError("Specified fold (or subset) does not exist")

         return TransformSubset(self, indices, train=train)
-=======
-        return TransformSubset(self, self.splits[split_id][0], train=train)
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics

     def replace_data(self, X_train: BaseDatasetInputType,
                      X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
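
The resolved branch keeps the HEAD behaviour: the no-resampling validator now receives self.random_state, and the resulting split stores None for the validation indices. Below is a minimal sketch of what that branch computes; no_resampling and the bare indices array are simplified stand-ins for the real validator lookup and self._get_indices(), not autoPyTorch's actual internals.

    import numpy as np

    def no_resampling(random_state: np.random.RandomState,
                      indices: np.ndarray) -> np.ndarray:
        # No split is performed: every sample index is kept for training.
        return indices

    random_state = np.random.RandomState(seed=42)  # mirrors seed: Optional[int] = 42
    indices = np.arange(10)                        # stand-in for self._get_indices()
    splits = [(no_resampling(random_state, indices), None)]  # (train indices, no val)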

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 4 additions & 11 deletions
@@ -39,8 +39,10 @@ def __call__(self, random_state: np.random.RandomState, val_share: float,
         ...


-class NO_RESAMPLING_FN(Protocol):
-    def __call__(self, indices: np.ndarray) -> np.ndarray:
+class NoResamplingFunc(Protocol):
+    def __call__(self,
+                 random_state: np.random.RandomState,
+                 indices: np.ndarray) -> np.ndarray:
         ...


@@ -90,22 +92,13 @@ def is_stratified(self) -> bool:

 class NoResamplingStrategyTypes(IntEnum):
     no_resampling = 8
-<<<<<<< HEAD

     def is_stratified(self) -> bool:
         return False


 # TODO: replace it with another way
 ResamplingStrategies = Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]
-=======
-    shuffle_no_resampling = 9
-
-
-# TODO: replace it with another way
-RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]
-
-
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics

 DEFAULT_RESAMPLING_PARAMETERS: Dict[
     ResamplingStrategies,
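
With the renamed NoResamplingFunc protocol, every validator takes the shared random state as its first argument, so a shuffled variant (the role the discarded shuffle_no_resampling enum member hinted at) can be expressed with the same signature. A hedged sketch; shuffled_no_resampling is illustrative only, not part of the library:

    from typing import Protocol

    import numpy as np

    class NoResamplingFunc(Protocol):
        def __call__(self,
                     random_state: np.random.RandomState,
                     indices: np.ndarray) -> np.ndarray:
            ...

    def shuffled_no_resampling(random_state: np.random.RandomState,
                               indices: np.ndarray) -> np.ndarray:
        # Shuffle in place with the shared random state, then keep all indices.
        random_state.shuffle(indices)
        return indices

    fn: NoResamplingFunc = shuffled_no_resampling  # satisfies the protocol under mypy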

autoPyTorch/pipeline/components/training/trainer/cutout_utils.py

Lines changed: 8 additions & 1 deletion
@@ -10,6 +10,8 @@

 import numpy as np

+from sklearn.utils import check_random_state
+
 from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
 from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
@@ -35,7 +37,12 @@ def __init__(self, patch_ratio: float,
         """
         self.use_stochastic_weight_averaging = use_stochastic_weight_averaging
         self.weighted_loss = weighted_loss
-        self.random_state = random_state
+        if random_state is None:
+            # A trainer components need a random state for
+            # sampling -- for example in MixUp training
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = random_state
         self.use_snapshot_ensemble = use_snapshot_ensemble
         self.se_lastk = se_lastk
         self.use_lookahead_optimizer = use_lookahead_optimizer
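
sklearn.utils.check_random_state accepts None, an int seed, or an existing RandomState and always returns a usable np.random.RandomState, which is what makes it a safe fallback here. A quick illustration of its semantics:

    import numpy as np

    from sklearn.utils import check_random_state

    rs = check_random_state(1)            # int seed -> np.random.RandomState(1)
    assert isinstance(rs, np.random.RandomState)
    assert check_random_state(rs) is rs   # an existing RandomState passes through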

autoPyTorch/pipeline/components/training/trainer/mixup_utils.py

Lines changed: 8 additions & 1 deletion
@@ -10,6 +10,8 @@

 import numpy as np

+from sklearn.utils import check_random_state
+
 from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
 from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
@@ -34,7 +36,12 @@ def __init__(self, alpha: float,
         """
         self.use_stochastic_weight_averaging = use_stochastic_weight_averaging
         self.weighted_loss = weighted_loss
-        self.random_state = random_state
+        if random_state is None:
+            # A trainer components need a random state for
+            # sampling -- for example in MixUp training
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = random_state
         self.use_snapshot_ensemble = use_snapshot_ensemble
         self.se_lastk = se_lastk
         self.use_lookahead_optimizer = use_lookahead_optimizer
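
The guaranteed random state matters for MixUp because both the mixing coefficient and the sample pairing are drawn per batch; seeding makes those draws reproducible. A sketch of the standard MixUp sampling recipe, with an illustrative alpha and batch size rather than autoPyTorch's exact trainer code:

    import numpy as np

    from sklearn.utils import check_random_state

    random_state = check_random_state(1)
    alpha = 0.2
    lam = random_state.beta(alpha, alpha)       # mixing coefficient in [0, 1]
    permutation = random_state.permutation(32)  # pairing for a batch of 32
    # mixed_x = lam * x + (1 - lam) * x[permutation] would then blend the inputs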

setup.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@
         "pre-commit",
         "pytest-cov",
         'pytest-forked',
-        "pytest-mock"
+        "pytest-mock",
         "codecov",
         "pep8",
         "mypy",

test/test_api/test_api.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@
 from autoPyTorch.pipeline.components.setup.traditional_ml.traditional_learner import _traditional_learners
 from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy

-from test.test_api.api_utils import print_debug_information
+from test.test_api.api_utils import print_debug_information  # noqa E402


 CV_NUM_SPLITS = 2
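
The # noqa E402 comment suppresses flake8's E402 ("module level import not at top of file"), which fires when any statement precedes an import. A generic illustration, since the statements preceding this import in test_api.py fall outside the hunk shown above:

    import sys

    sys.path.insert(0, ".")  # any module-level statement before an import ...

    import json  # noqa E402  (... makes flake8 flag this import without the noqa)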
