Skip to content

Commit 8fb0bc2

Browse files
authored
Shake Shake updates (#287)
* To test locally * fix bug in trainer choice fit * fix ensemble bug * Correct bug in cleanup * To test locally * Cleanup for removing time debug statements * ablation for adversarial * shuffle false in dataloader * drop last false in dataloader * fix bug for validation set, and cutout and cutmix * To test locally * shuffle = False * To test locally * updates to search space * updates to search space * update branch with search space * undo search space update * fix bug in shake shake flag * limit to shake-even * restrict to even even * Add even even and others for shake-drop also * fix bug in passing alpha beta method * restrict to only even even * fix silly bug: * remove imputer and ordinal encoder for categorical transformer in feature validator * Address comments from shuhei
1 parent 209a4e8 commit 8fb0bc2

File tree

8 files changed

+88
-33
lines changed

8 files changed

+88
-33
lines changed

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,26 @@ def get_tabular_preprocessors():
4141
preprocessors['numerical'] = list()
4242
preprocessors['categorical'] = list()
4343

44+
# preprocessors['categorical'].append(SimpleImputer(strategy='constant',
45+
# # Train data is numpy
46+
# # as of this point, where
47+
# # Ordinal Encoding is using
48+
# # for categorical. Only
49+
# # Numbers are allowed
50+
# # fill_value='!missing!',
51+
# fill_value=-1,
52+
# copy=False))
53+
54+
# preprocessors['categorical'].append(OrdinalEncoder(
55+
# handle_unknown='use_encoded_value',
56+
# unknown_value=-1))
57+
4458
preprocessors['categorical'].append(OneHotEncoder(
4559
categories='auto',
4660
sparse=False,
4761
handle_unknown='ignore'))
48-
preprocessors['categorical'].append(SimpleImputer(strategy='constant',
49-
# Train data is numpy
50-
# as of this point, where
51-
# Ordinal Encoding is using
52-
# for categorical. Only
53-
# Numbers are allowed
54-
# fill_value='!missing!',
55-
fill_value=-1,
56-
copy=False))
57-
58-
preprocessors['categorical'].append(OrdinalEncoder(
59-
handle_unknown='use_encoded_value',
60-
unknown_value=-1))
61-
6262
preprocessors['numerical'].append(SimpleImputer(strategy='median',
63-
copy=False))
63+
copy=False))
6464
preprocessors['numerical'].append(StandardScaler(with_mean=True, with_std=True, copy=False))
6565

6666
return preprocessors

autoPyTorch/pipeline/base_pipeline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,12 +451,13 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
451451
continue
452452
raise ValueError("Unknown hyperparameter for component {}. "
453453
"Expected update hyperparameter "
454-
"to be in {} got {}".format(node.__class__.__name__,
454+
"to be in {} got {}. choice is {}".format(node.__class__.__name__,
455455
component.
456456
get_hyperparameter_search_space(
457457
dataset_properties=self.dataset_properties).
458458
get_hyperparameter_names(),
459-
split_hyperparameter[1]))
459+
split_hyperparameter[1],
460+
component.__name__))
460461
else:
461462
if update.hyperparameter not in node.get_hyperparameter_search_space(
462463
dataset_properties=self.dataset_properties):

autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,14 @@ def get_hyperparameter_search_space(
139139
value_range=(True, False),
140140
default_value=True,
141141
),
142+
shake_alpha_beta_method: HyperparameterSearchSpace = HyperparameterSearchSpace(
143+
hyperparameter="shake_alpha_beta_method",
144+
value_range=('shake-shake',
145+
'shake-even',
146+
'even-even',
147+
'M3'),
148+
default_value='shake-shake',
149+
),
142150
use_shake_drop: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="use_shake_drop",
143151
value_range=(True, False),
144152
default_value=True,
@@ -180,9 +188,8 @@ def get_hyperparameter_search_space(
180188

181189
if skip_connection_flag:
182190

183-
shake_drop_prob_flag = False
184-
if 'shake-drop' in multi_branch_choice.value_range:
185-
shake_drop_prob_flag = True
191+
shake_shake_flag = 'shake-shake' in multi_branch_choice.value_range
192+
shake_drop_prob_flag = 'shake-drop' in multi_branch_choice.value_range
186193

187194
mb_choice = get_hyperparameter(multi_branch_choice, CategoricalHyperparameter)
188195
cs.add_hyperparameter(mb_choice)
@@ -192,6 +199,10 @@ def get_hyperparameter_search_space(
192199
shake_drop_prob = get_hyperparameter(max_shake_drop_probability, UniformFloatHyperparameter)
193200
cs.add_hyperparameter(shake_drop_prob)
194201
cs.add_condition(CS.EqualsCondition(shake_drop_prob, mb_choice, "shake-drop"))
202+
if shake_shake_flag or shake_drop_prob_flag:
203+
method = get_hyperparameter(shake_alpha_beta_method, CategoricalHyperparameter)
204+
cs.add_hyperparameter(method)
205+
cs.add_condition(CS.InCondition(method, mb_choice, ["shake-shake", "shake-drop"]))
195206

196207
# It is the upper bound of the nr of groups,
197208
# since the configuration will actually be sampled.
@@ -327,11 +338,14 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
327338
if self.config["multi_branch_choice"] == 'shake-shake':
328339
x1 = self.layers(x)
329340
x2 = self.shake_shake_layers(x)
330-
alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda)
341+
alpha, beta = shake_get_alpha_beta(is_training=self.training,
342+
is_cuda=x.is_cuda,
343+
method=self.config['shake_alpha_beta_method'])
331344
x = shake_shake(x1, x2, alpha, beta)
332345
elif self.config["multi_branch_choice"] == 'shake-drop':
333346
x = self.layers(x)
334-
alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda)
347+
alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda,
348+
method=self.config['shake_alpha_beta_method'])
335349
bl = shake_drop_get_bl(
336350
self.block_index,
337351
1 - self.config["max_shake_drop_probability"],

autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,14 @@ def get_hyperparameter_search_space( # type: ignore[override]
145145
'stairs'),
146146
default_value='funnel',
147147
),
148+
shake_alpha_beta_method: HyperparameterSearchSpace = HyperparameterSearchSpace(
149+
hyperparameter="shake_alpha_beta_method",
150+
value_range=('shake-shake',
151+
'shake-even',
152+
'even-even',
153+
'M3'),
154+
default_value='shake-shake',
155+
),
148156
max_shake_drop_probability: HyperparameterSearchSpace = HyperparameterSearchSpace(
149157
hyperparameter="max_shake_drop_probability",
150158
value_range=(0, 1),
@@ -188,9 +196,8 @@ def get_hyperparameter_search_space( # type: ignore[override]
188196

189197
if skip_connection_flag:
190198

191-
shake_drop_prob_flag = False
192-
if 'shake-drop' in multi_branch_choice.value_range:
193-
shake_drop_prob_flag = True
199+
shake_shake_flag = 'shake-shake' in multi_branch_choice.value_range
200+
shake_drop_prob_flag = 'shake-drop' in multi_branch_choice.value_range
194201

195202
mb_choice = get_hyperparameter(multi_branch_choice, CategoricalHyperparameter)
196203
cs.add_hyperparameter(mb_choice)
@@ -200,5 +207,9 @@ def get_hyperparameter_search_space( # type: ignore[override]
200207
shake_drop_prob = get_hyperparameter(max_shake_drop_probability, UniformFloatHyperparameter)
201208
cs.add_hyperparameter(shake_drop_prob)
202209
cs.add_condition(CS.EqualsCondition(shake_drop_prob, mb_choice, "shake-drop"))
210+
if shake_shake_flag or shake_drop_prob_flag:
211+
method = get_hyperparameter(shake_alpha_beta_method, CategoricalHyperparameter)
212+
cs.add_hyperparameter(method)
213+
cs.add_condition(CS.InCondition(method, mb_choice, ["shake-shake", "shake-drop"]))
203214

204215
return cs

autoPyTorch/pipeline/components/setup/network_backbone/utils.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,35 @@ def backward(ctx: typing.Any,
9292
shake_drop = ShakeDropFunction.apply
9393

9494

95-
def shake_get_alpha_beta(is_training: bool, is_cuda: bool
96-
) -> typing.Tuple[torch.tensor, torch.tensor]:
95+
def shake_get_alpha_beta(
96+
is_training: bool,
97+
is_cuda: bool,
98+
method: str
99+
) -> typing.Tuple[torch.tensor, torch.tensor]:
100+
"""
101+
The methods used in this function have been introduced in 'ShakeShake Regularisation'
102+
https://arxiv.org/abs/1705.07485. The names have been taken from the paper as well.
103+
"""
97104
if not is_training:
98105
result = (torch.FloatTensor([0.5]), torch.FloatTensor([0.5]))
99106
return result if not is_cuda else (result[0].cuda(), result[1].cuda())
100107

101108
# TODO implement other update methods
102-
alpha = torch.rand(1)
103-
beta = torch.rand(1)
109+
if method == 'even-even':
110+
alpha = torch.FloatTensor([0.5])
111+
else:
112+
alpha = torch.rand(1)
113+
114+
if method == 'shake-shake':
115+
beta = torch.rand(1)
116+
elif method in ['shake-even', 'even-even']:
117+
beta = torch.FloatTensor([0.5])
118+
elif method == 'M3':
119+
beta = torch.FloatTensor(
120+
[torch.rand(1)*(0.5 - alpha)*alpha if alpha < 0.5 else torch.rand(1)*(alpha - 0.5)*alpha]
121+
)
122+
else:
123+
raise ValueError("Unknown method for ShakeShakeRegularisation in NetworkBackbone")
104124

105125
if is_cuda:
106126
alpha = alpha.cuda()

autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def get_hyperparameter_search_space(
9595
default_value=True,
9696
),
9797
weight_decay: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="weight_decay",
98-
value_range=(1E-7, 0.1),
98+
value_range=(1E-5, 0.1),
9999
default_value=1E-4,
100-
log=True),
100+
log=False),
101101
) -> ConfigurationSpace:
102102
cs = ConfigurationSpace()
103103

autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> torch.utils.data.DataLoader:
112112
self.train_data_loader = torch.utils.data.DataLoader(
113113
train_dataset,
114114
batch_size=min(self.batch_size, len(train_dataset)),
115-
shuffle=False,
115+
shuffle=True,
116116
num_workers=X.get('num_workers', 0),
117117
pin_memory=X.get('pin_memory', True),
118118
drop_last=X.get('drop_last', False),

examples/tabular/40_advanced/example_custom_configuration_space.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ def get_search_space_updates():
5454
hyperparameter='ResNetBackbone:dropout',
5555
value_range=[0, 0.5],
5656
default_value=0.2)
57+
updates.append(node_name='network_backbone',
58+
hyperparameter='ResNetBackbone:multi_branch_choice',
59+
value_range=['shake-shake'],
60+
default_value='shake-shake')
61+
updates.append(node_name='network_backbone',
62+
hyperparameter='ResNetBackbone:shake_shake_method',
63+
value_range=['M3'],
64+
default_value='M3'
65+
)
5766
return updates
5867

5968

@@ -74,7 +83,7 @@ def get_search_space_updates():
7483
# ==================================================
7584
api = TabularClassificationTask(
7685
search_space_updates=get_search_space_updates(),
77-
include_components={'network_backbone': ['MLPBackbone', 'ResNetBackbone'],
86+
include_components={'network_backbone': ['ResNetBackbone'],
7887
'encoder': ['OneHotEncoder']}
7988
)
8089

0 commit comments

Comments
 (0)