Skip to content

Commit c4b7729

Browse files
committed
Fix bugs in cutout training (#233)
* Fix bugs in cutout training
* Address comments from Arlind
1 parent 622c185 commit c4b7729

File tree

4 files changed

+28
-24
lines changed

4 files changed

+28
-24
lines changed

autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
3535
if beta <= 0 or r > self.alpha:
3636
return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
3737

38-
# The mixup component mixes up also on the batch dimension
39-
# It is unlikely that the batch size is lower than the number of features, but
40-
# be safe
41-
size = min(X.shape[0], X.shape[1])
42-
indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * lam))))
38+
size = X.shape[1]
39+
indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
40+
replace=False))
4341

4442
X[:, indices] = X[index, :][:, indices]
4543

autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010

1111
class RowCutOutTrainer(CutOut, BaseTrainerComponent):
12+
NUMERICAL_VALUE = 0
13+
CATEGORICAL_VALUE = -1
1214

1315
def data_preparation(self, X: np.ndarray, y: np.ndarray,
1416
) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
@@ -34,17 +36,26 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
3436
lam = 1
3537
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
3638

37-
# The mixup component mixes up also on the batch dimension
38-
# It is unlikely that the batch size is lower than the number of features, but
39-
# be safe
40-
size = min(X.shape[0], X.shape[1])
41-
indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * self.patch_ratio))))
39+
size = X.shape[1]
40+
indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
41+
replace=False)
4242

43-
# We use an ordinal encoder on the tabular data
43+
if not isinstance(self.numerical_columns, typing.Iterable):
44+
raise ValueError("{} requires numerical columns information of {}"
45+
"to prepare data got {}.".format(self.__class__.__name__,
46+
typing.Iterable,
47+
self.numerical_columns))
48+
numerical_indices = torch.tensor(self.numerical_columns)
49+
categorical_indices = torch.tensor([index for index in indices if index not in self.numerical_columns])
50+
51+
# We use an ordinal encoder on the categorical columns of tabular data
4452
# -1 is the conceptual equivalent to 0 in a image, that does not
4553
# have color as a feature and hence the network has to learn to deal
46-
# without this data
47-
X[:, indices.long()] = -1
54+
# without this data. For numerical columns we use 0 to cutout the features
55+
# similar to the effect that setting 0 as a pixel value in an image.
56+
X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
57+
X[:, numerical_indices.long()] = self.NUMERICAL_VALUE
58+
4859
lam = 1
4960
y_a = y
5061
y_b = y

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ def prepare_trainer(self, X: Dict) -> None:
319319
task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
320320
labels=labels,
321321
step_interval=X['step_interval']
322+
numerical_columns=X['dataset_properties']['numerical_columns'] if 'numerical_columns' in X[
323+
'dataset_properties'] else None
322324
)
323325

324326
def get_budget_tracker(self, X: Dict) -> BudgetTracker:
@@ -396,11 +398,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
396398

397399
val_loss, val_metrics, test_loss, test_metrics = None, {}, None, {}
398400
if self.eval_valid_each_epoch(X):
399-
<<<<<<< HEAD
400-
if X['val_data_loader']:
401-
=======
402401
if 'val_data_loader' in X and X['val_data_loader']:
403-
>>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
404402
val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
405403
if 'test_data_loader' in X and X['test_data_loader']:
406404
test_loss, test_metrics = self.choice.evaluate(X['test_data_loader'], epoch, writer)
@@ -454,17 +452,10 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
454452

455453
# wrap up -- add score if not evaluating every epoch
456454
if not self.eval_valid_each_epoch(X):
457-
<<<<<<< HEAD
458-
if X['val_data_loader']:
459-
val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
460-
if 'test_data_loader' in X and X['val_data_loader']:
461-
test_loss, test_metrics = self.choice.evaluate(X['test_data_loader'], epoch, writer)
462-
=======
463455
if 'val_data_loader' in X and X['val_data_loader']:
464456
val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
465457
if 'test_data_loader' in X and X['test_data_loader']:
466458
test_loss, test_metrics = self.choice.evaluate(X['test_data_loader'])
467-
>>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
468459
self.run_summary.add_performance(
469460
epoch=epoch,
470461
start_time=start_time,

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def prepare(
273273
task_type: int,
274274
labels: Union[np.ndarray, torch.Tensor, pd.DataFrame],
275275
step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.batch,
276+
numerical_columns: Optional[List[int]] = None,
276277
**kwargs: Dict
277278
) -> None:
278279

@@ -330,6 +331,9 @@ def prepare(
330331
# task type (used for calculating metrics)
331332
self.task_type = task_type
332333

334+
# for cutout trainer, we need the list of numerical columns
335+
self.numerical_columns = numerical_columns
336+
333337
def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
334338
"""
335339
Optional place holder for AutoPytorch Extensions.

0 commit comments

Comments (0)