
Commit a831767

Fix bugs in cutout training (#233)
* Fix bugs in cutout training
* Address comments from arlind
1 parent a21c2e4 commit a831767

4 files changed: +30 -15 lines changed

autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 3 additions & 5 deletions
@@ -35,11 +35,9 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
         if beta <= 0 or r > self.alpha:
             return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
 
-        # The mixup component mixes up also on the batch dimension
-        # It is unlikely that the batch size is lower than the number of features, but
-        # be safe
-        size = min(X.shape[0], X.shape[1])
-        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * lam))))
+        size = X.shape[1]
+        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
+                                                        replace=False))
 
         X[:, indices] = X[index, :][:, indices]
 
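To illustrate the change, here is a minimal, self-contained sketch of the new column selection on a toy batch (the shapes, seed, and mixing coefficient below are assumptions, not part of the commit): columns are now sampled over the feature dimension only, never the batch dimension, and replace=False guarantees each column is picked at most once.

import numpy as np
import torch

random_state = np.random.RandomState(0)

X = torch.arange(20, dtype=torch.float32).reshape(4, 5)  # toy batch: 4 rows, 5 features
index = torch.randperm(X.shape[0])                        # shuffled row order used by CutMix
lam = 0.6                                                 # assumed mixing coefficient

size = X.shape[1]                                         # number of features, not min(batch, features)
indices = torch.tensor(random_state.choice(range(1, size), max(1, int(size * lam)),
                                           replace=False))
X[:, indices] = X[index, :][:, indices]                   # swap the chosen feature columns between rows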

autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 19 additions & 8 deletions
@@ -9,6 +9,8 @@
 
 
 class RowCutOutTrainer(CutOut, BaseTrainerComponent):
+    NUMERICAL_VALUE = 0
+    CATEGORICAL_VALUE = -1
 
     def data_preparation(self, X: np.ndarray, y: np.ndarray,
                          ) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
@@ -34,17 +36,26 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             lam = 1
             return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
 
-        # The mixup component mixes up also on the batch dimension
-        # It is unlikely that the batch size is lower than the number of features, but
-        # be safe
-        size = min(X.shape[0], X.shape[1])
-        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * self.patch_ratio))))
+        size = X.shape[1]
+        indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
+                                           replace=False)
 
-        # We use an ordinal encoder on the tabular data
+        if not isinstance(self.numerical_columns, typing.Iterable):
+            raise ValueError("{} requires numerical columns information of {}"
+                             "to prepare data got {}.".format(self.__class__.__name__,
+                                                              typing.Iterable,
+                                                              self.numerical_columns))
+        numerical_indices = torch.tensor(self.numerical_columns)
+        categorical_indices = torch.tensor([index for index in indices if index not in self.numerical_columns])
+
+        # We use an ordinal encoder on the categorical columns of tabular data
         # -1 is the conceptual equivalent to 0 in a image, that does not
         # have color as a feature and hence the network has to learn to deal
-        # without this data
-        X[:, indices.long()] = -1
+        # without this data. For numerical columns we use 0 to cutout the features
+        # similar to the effect that setting 0 as a pixel value in an image.
+        X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
+        X[:, numerical_indices.long()] = self.NUMERICAL_VALUE
+
         lam = 1
         y_a = y
         y_b = y
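The effect of the cutout change can be seen in a short, self-contained sketch on toy data (the column layout, patch ratio, and seed below are assumptions): categorical columns are masked with -1, a value the ordinal encoder never produces, while numerical columns are masked with 0, the tabular analogue of zeroing out pixels in an image.

import numpy as np
import torch

NUMERICAL_VALUE = 0
CATEGORICAL_VALUE = -1

random_state = np.random.RandomState(0)

X = torch.ones(4, 6)                  # toy batch: 4 rows, 6 features
numerical_columns = [0, 1, 2]         # assumed layout: first three columns are numerical
patch_ratio = 0.5                     # assumed cutout ratio

size = X.shape[1]
indices = random_state.choice(range(1, size), max(1, int(size * patch_ratio)), replace=False)

# mirror the committed logic: every numerical column is masked with 0, while
# categorical columns are masked with -1 only where they were sampled above
numerical_indices = torch.tensor(numerical_columns)
categorical_indices = torch.tensor([index for index in indices if index not in numerical_columns])

X[:, categorical_indices.long()] = CATEGORICAL_VALUE
X[:, numerical_indices.long()] = NUMERICAL_VALUE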

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -343,7 +343,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
             scheduler=X['lr_scheduler'],
             task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
             labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]],
-            step_interval=X['step_interval']
+            step_interval=X['step_interval'],
+            numerical_columns=X['dataset_properties']['numerical_columns'] if 'numerical_columns' in X[
+                'dataset_properties'] else None
         )
         total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
         self.run_summary = RunSummary(
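A small sketch of the forwarding logic added here, with a hypothetical dataset_properties dictionary standing in for the real pipeline state: the trainer receives the numerical column indices when the dataset properties carry them, and None otherwise.

from typing import Any, Dict, List, Optional

dataset_properties: Dict[str, Any] = {
    'task_type': 'tabular_classification',   # assumed contents, for illustration only
    'numerical_columns': [0, 2, 5],
}

numerical_columns: Optional[List[int]] = (dataset_properties['numerical_columns']
                                          if 'numerical_columns' in dataset_properties
                                          else None)

The same fallback could also be written as dataset_properties.get('numerical_columns').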

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 5 additions & 1 deletion
@@ -265,7 +265,8 @@ def prepare(
         scheduler: _LRScheduler,
         task_type: int,
         labels: Union[np.ndarray, torch.Tensor, pd.DataFrame],
-        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.batch
+        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.batch,
+        numerical_columns: Optional[List[int]] = None
     ) -> None:
 
         # Save the device to be used
@@ -322,6 +323,9 @@ def prepare(
         # task type (used for calculating metrics)
         self.task_type = task_type
 
+        # for cutout trainer, we need the list of numerical columns
+        self.numerical_columns = numerical_columns
+
     def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
         """
         Optional place holder for AutoPytorch Extensions.
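A reduced sketch of the extended prepare() contract (all other parameters omitted, class name hypothetical): the trainer simply stores the numerical column indices so that cutout-style trainers can later tell numerical columns from categorical ones.

from typing import List, Optional

class ToyTrainer:
    def prepare(self, numerical_columns: Optional[List[int]] = None) -> None:
        # for cutout trainers, keep the list of numerical columns around
        self.numerical_columns = numerical_columns

trainer = ToyTrainer()
trainer.prepare(numerical_columns=[0, 2, 5])   # column indices assumed for illustration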
