
Commit 8b8ba42

Fix bugs in cutout training (#233)
* Fix bugs in cutout training
* Address comments from arlind
1 parent 6a240a6 commit 8b8ba42

File tree

4 files changed, +30 -15 lines changed


autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 3 additions & 5 deletions
@@ -35,11 +35,9 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
         if beta <= 0 or r > self.alpha:
             return X, {'y_a': y, 'y_b': y[index], 'lam': 1}

-        # The mixup component mixes up also on the batch dimension
-        # It is unlikely that the batch size is lower than the number of features, but
-        # be safe
-        size = min(X.shape[0], X.shape[1])
-        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * lam))))
+        size = X.shape[1]
+        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
+                                                        replace=False))

         X[:, indices] = X[index, :][:, indices]
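
Note: the CutMix change above stops sampling from min(batch_size, n_features) and instead samples feature columns directly, without replacement. Below is a minimal standalone sketch of column-wise CutMix with the fixed indexing; the function name, the cloning, and the explicit arguments are illustrative assumptions, since in autoPyTorch this logic lives inside RowCutMixTrainer.data_preparation and also returns the mixed targets.

```python
import numpy as np
import torch

def column_cutmix(X: torch.Tensor, lam: float,
                  random_state: np.random.RandomState) -> torch.Tensor:
    """Mix a random subset of feature columns between shuffled rows (sketch)."""
    batch_size, num_features = X.shape
    index = torch.randperm(batch_size)          # rows whose values are mixed in
    n_cols = max(1, int(num_features * lam))    # how many columns to swap
    # Sample column indices only (not batch positions), without replacement,
    # mirroring the fixed code above.
    indices = torch.tensor(random_state.choice(range(1, num_features), n_cols,
                                               replace=False))
    X = X.clone()
    X[:, indices] = X[index, :][:, indices]
    return X

# Example usage
X_mixed = column_cutmix(torch.randn(8, 10), lam=0.4,
                        random_state=np.random.RandomState(0))
```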

autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 19 additions & 8 deletions
@@ -9,6 +9,8 @@


 class RowCutOutTrainer(CutOut, BaseTrainerComponent):
+    NUMERICAL_VALUE = 0
+    CATEGORICAL_VALUE = -1

     def data_preparation(self, X: np.ndarray, y: np.ndarray,
                          ) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
@@ -34,17 +36,26 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             lam = 1
             return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}

-        # The mixup component mixes up also on the batch dimension
-        # It is unlikely that the batch size is lower than the number of features, but
-        # be safe
-        size = min(X.shape[0], X.shape[1])
-        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int(size * self.patch_ratio))))
+        size = X.shape[1]
+        indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
+                                           replace=False)

-        # We use an ordinal encoder on the tabular data
+        if not isinstance(self.numerical_columns, typing.Iterable):
+            raise ValueError("{} requires numerical columns information of {}"
+                             "to prepare data got {}.".format(self.__class__.__name__,
+                                                              typing.Iterable,
+                                                              self.numerical_columns))
+        numerical_indices = torch.tensor(self.numerical_columns)
+        categorical_indices = torch.tensor([index for index in indices if index not in self.numerical_columns])
+
+        # We use an ordinal encoder on the categorical columns of tabular data
         # -1 is the conceptual equivalent to 0 in a image, that does not
         # have color as a feature and hence the network has to learn to deal
-        # without this data
-        X[:, indices.long()] = -1
+        # without this data. For numerical columns we use 0 to cutout the features
+        # similar to the effect that setting 0 as a pixel value in an image.
+        X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
+        X[:, numerical_indices.long()] = self.NUMERICAL_VALUE
+
         lam = 1
         y_a = y
         y_b = y
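
Note: condensed, the new cutout fill rule is that sampled categorical columns (ordinally encoded) are set to -1, while numerical columns are zeroed, analogous to blanking pixels in an image. The sketch below is a standalone illustration under those assumptions; the function name and explicit arguments are not the real API (the trainer reads patch_ratio, numerical_columns and random_state from self), and it mirrors the committed behaviour of zeroing all numerical columns rather than only the sampled ones.

```python
import numpy as np
import torch

NUMERICAL_VALUE = 0
CATEGORICAL_VALUE = -1

def row_cutout(X: torch.Tensor, patch_ratio: float, numerical_columns: list,
               random_state: np.random.RandomState) -> torch.Tensor:
    """Cut out a random subset of feature columns with type-aware fill values (sketch)."""
    num_features = X.shape[1]
    indices = random_state.choice(range(1, num_features),
                                  max(1, int(num_features * patch_ratio)),
                                  replace=False)
    numerical_indices = torch.tensor(numerical_columns)
    categorical_indices = torch.tensor([i for i in indices
                                        if i not in numerical_columns])
    X = X.clone()
    X[:, categorical_indices.long()] = CATEGORICAL_VALUE  # ordinal categories -> -1
    X[:, numerical_indices.long()] = NUMERICAL_VALUE      # numerical features -> 0
    return X

# Example usage: columns 0-4 numerical, 5-9 categorical
X_cut = row_cutout(torch.rand(8, 10), patch_ratio=0.3,
                   numerical_columns=[0, 1, 2, 3, 4],
                   random_state=np.random.RandomState(1))
```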

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -342,7 +342,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
             scheduler=X['lr_scheduler'],
             task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
             labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]],
-            step_interval=X['step_interval']
+            step_interval=X['step_interval'],
+            numerical_columns=X['dataset_properties']['numerical_columns'] if 'numerical_columns' in X[
+                'dataset_properties'] else None
         )
         total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
         self.run_summary = RunSummary(
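
Note: the new argument is looked up defensively, so pipelines whose dataset_properties lack 'numerical_columns' keep working and the trainer simply receives None. The conditional expression behaves like a plain dict .get with a None default; the snippet below is only an illustration of that fallback, not what the commit uses.

```python
# Hypothetical illustration of the fallback: equivalent to the conditional
# expression above, returning None when the key is absent.
dataset_properties = {'task_type': 'tabular_classification'}  # no 'numerical_columns'
numerical_columns = dataset_properties.get('numerical_columns', None)
assert numerical_columns is None
```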

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 5 additions & 1 deletion
@@ -253,7 +253,8 @@ def prepare(
         scheduler: _LRScheduler,
         task_type: int,
         labels: Union[np.ndarray, torch.Tensor, pd.DataFrame],
-        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.batch
+        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.batch,
+        numerical_columns: Optional[List[int]] = None
     ) -> None:

         # Save the device to be used
@@ -310,6 +311,9 @@ def prepare(
         # task type (used for calculating metrics)
         self.task_type = task_type

+        # for cutout trainer, we need the list of numerical columns
+        self.numerical_columns = numerical_columns
+
     def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
         """
         Optional place holder for AutoPytorch Extensions.
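
Note: taken together with the trainer __init__.py change, prepare() now carries the numerical-column list down to the trainer, defaulting to None. The sketch below only illustrates that contract under placeholder names (it is not the real BaseTrainerComponent); trainers that rely on the list, such as the cutout trainer above, have to cope with the None default, which is why RowCutOutTrainer validates it before use.

```python
from typing import List, Optional

class TrainerSketch:
    """Placeholder illustrating the prepare()/data_preparation() contract."""

    def prepare(self, numerical_columns: Optional[List[int]] = None) -> None:
        # for the cutout trainer, we need the list of numerical columns
        self.numerical_columns = numerical_columns

    def data_preparation(self) -> None:
        if self.numerical_columns is None:
            raise ValueError("numerical column information is required to prepare data, "
                             "got {}".format(self.numerical_columns))

# Example usage
trainer = TrainerSketch()
trainer.prepare(numerical_columns=[0, 2, 5])
trainer.data_preparation()
```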
