
Commit ee685a9

Cocktail hotfixes (#245)
* Fixes for the development branch and regularization cocktails
* Update implementation
* Fix unit tests temporarily
* Implementation update and bug fixes
* Removing unnecessary code
* Addressing Ravin's comments
1 parent 463c166 commit ee685a9

File tree: 4 files changed, +42 −24 lines

autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py

Lines changed: 16 additions & 5 deletions
@@ -41,7 +41,7 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None:
                     out_features=self.config["num_units_%d" % i],
                     blocks_per_group=self.config["blocks_per_group_%d" % i],
                     last_block_index=(i - 1) * self.config["blocks_per_group_%d" % i],
-                    dropout=self.config['use_dropout']
+                    dropout=self.config[f'dropout_{i}'] if self.config['use_dropout'] else None,
                 )
             )
         if self.config['use_batch_norm']:
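
The recurring change in this commit turns the `dropout` argument from a boolean flag into a per-group dropout rate (or None when dropout is disabled). A minimal sketch of how the new keys are resolved, with illustrative config values that are not taken from the repository:

    # Hypothetical config for illustration only.
    config = {
        'use_dropout': True,
        'num_groups': 3,
        'dropout_1': 0.2,
        'dropout_2': 0.4,
        'dropout_3': 0.6,
    }

    for i in range(1, config['num_groups'] + 1):
        # Mirrors the diff: each group gets its own rate; None disables dropout.
        dropout = config[f'dropout_{i}'] if config['use_dropout'] else None
        print(f"group {i}: dropout={dropout}")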
@@ -52,7 +52,7 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None:
         return backbone

     def _add_group(self, in_features: int, out_features: int,
-                   blocks_per_group: int, last_block_index: int, dropout: bool
+                   blocks_per_group: int, last_block_index: int, dropout: Optional[float]
                    ) -> nn.Module:
         """
         Adds a group into the main backbone.
@@ -64,7 +64,8 @@ def _add_group(self, in_features: int, out_features: int,
             out_features (int): output dimensionality for the current block
             blocks_per_group (int): Number of ResNet blocks per group
             last_block_index (int): block index for shake regularization
-            dropout (bool): whether or not use dropout
+            dropout (None, float): dropout value for the group. If None,
+                no dropout is applied.
         """
         blocks = list()
         for i in range(blocks_per_group):
@@ -245,7 +246,7 @@ def __init__(
         out_features: int,
         blocks_per_group: int,
         block_index: int,
-        dropout: bool,
+        dropout: Optional[float],
         activation: nn.Module
     ):
         super(ResBlock, self).__init__()
@@ -289,13 +290,22 @@ def _build_block(self, in_features: int, out_features: int) -> nn.Module:
             if self.config['use_batch_norm']:
                 layers.append(nn.BatchNorm1d(in_features))
             layers.append(self.activation())
+        else:
+            # if start_norm is not None and the skip connection is None,
+            # we will never apply the start_norm for the first layer in the block,
+            # which is why we should account for this case.
+            if not self.config['use_skip_connection']:
+                if self.config['use_batch_norm']:
+                    layers.append(nn.BatchNorm1d(in_features))
+                layers.append(self.activation())
+
         layers.append(nn.Linear(in_features, out_features))

         if self.config['use_batch_norm']:
             layers.append(nn.BatchNorm1d(out_features))
         layers.append(self.activation())

-        if self.config["use_dropout"]:
+        if self.dropout is not None:
             layers.append(nn.Dropout(self.dropout))
         layers.append(nn.Linear(out_features, out_features))

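For orientation, here is a rough sketch of the layer order `_build_block` now produces, assuming `use_batch_norm=True` and a ReLU activation; this is a simplified stand-in, not the repository's code:

    import torch.nn as nn

    def sketch_block(in_features: int, out_features: int,
                     dropout, use_skip_connection: bool) -> nn.Sequential:
        layers = []
        if not use_skip_connection:
            # Without a skip connection the surrounding start_norm is never
            # applied, so the pre-activation is prepended here (the new
            # else-branch in the diff above).
            layers += [nn.BatchNorm1d(in_features), nn.ReLU()]
        layers.append(nn.Linear(in_features, out_features))
        layers += [nn.BatchNorm1d(out_features), nn.ReLU()]
        if dropout is not None:  # dropout is now a rate, not a bool
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(out_features, out_features))
        return nn.Sequential(*layers)

    print(sketch_block(64, 128, dropout=0.3, use_skip_connection=False))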
@@ -320,6 +330,7 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         if self.config["use_skip_connection"]:
             residual = self.shortcut(x)

+        # TODO make the below code better
         if self.config["use_skip_connection"]:
             if self.config["multi_branch_choice"] == 'shake-shake':
                 x1 = self.layers(x)
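
The 'shake-shake' branch touched above mixes two residual branches with a random convex combination. A minimal sketch of the idea only (not autoPyTorch's implementation; the full method also redraws the coefficient on the backward pass):

    import torch

    def shake_shake_train(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # A fresh alpha is drawn on every forward pass during training.
        alpha = torch.rand(1)
        return alpha * x1 + (1 - alpha) * x2

    def shake_shake_eval(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # Both branches are weighted equally at inference time.
        return 0.5 * (x1 + x2)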

autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py

Lines changed: 19 additions & 13 deletions
@@ -30,22 +30,28 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None:
         out_features = self.config["output_dim"]

         # use the get_shaped_neuron_counts to update the number of units
-        neuron_counts = get_shaped_neuron_counts(self.config['resnet_shape'],
-                                                 in_features,
-                                                 out_features,
-                                                 self.config['max_units'],
-                                                 self.config['num_groups'] + 2)[:-1]
+        neuron_counts = get_shaped_neuron_counts(
+            self.config['resnet_shape'],
+            in_features,
+            out_features,
+            self.config['max_units'],
+            self.config['num_groups'] + 2,
+        )[:-1]
         self.config.update(
             {"num_units_%d" % (i): num for i, num in enumerate(neuron_counts)}
         )
-        if self.config['use_dropout'] and self.config["max_dropout"] > 0.05:
+        if self.config['use_dropout']:
+            # the last dropout ("neuron") value is skipped since it will be equal
+            # to output_feat, which is 0. It is also skipped when getting the
+            # number of units for the architecture, since it is mostly implemented
+            # for the output layer, which is part of the head and not of the backbone.
             dropout_shape = get_shaped_neuron_counts(
-                self.config['dropout_shape'], 0, 0, 1000, self.config['num_groups']
-            )
-
-            dropout_shape = [
-                dropout / 1000 * self.config["max_dropout"] for dropout in dropout_shape
-            ]
+                self.config['dropout_shape'],
+                0,
+                0,
+                self.config["max_dropout"],
+                self.config['num_groups'] + 1,
+            )[:-1]

         self.config.update(
             {"dropout_%d" % (i + 1): dropout for i, dropout in enumerate(dropout_shape)}
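
A worked illustration of the new computation, using a stand-in for get_shaped_neuron_counts that linearly interpolates a 'funnel' shape (assumed behavior, not the library's exact function): num_groups + 1 values are requested so the trailing zero belonging to the output-layer slot can be dropped with [:-1], leaving one rate per group, already scaled by max_dropout instead of the old divide-by-1000 rescaling.

    def funnel_counts(max_value: float, count: int) -> list:
        # Stand-in: interpolate linearly from max_value down to 0 over count steps.
        step = max_value / (count - 1)
        return [max_value - i * step for i in range(count)]

    num_groups, max_dropout = 3, 0.6
    dropout_shape = funnel_counts(max_dropout, num_groups + 1)[:-1]
    print([round(d, 2) for d in dropout_shape])  # [0.6, 0.4, 0.2] -> dropout_1..dropout_3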
@@ -61,7 +67,7 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> None:
                     out_features=self.config["num_units_%d" % i],
                     blocks_per_group=self.config["blocks_per_group"],
                     last_block_index=(i - 1) * self.config["blocks_per_group"],
-                    dropout=self.config['use_dropout']
+                    dropout=self.config[f'dropout_{i}'] if self.config['use_dropout'] else None,
                 )
             )
         if self.config['use_batch_norm']:

autoPyTorch/pipeline/components/setup/network_head/no_head.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@ class NoHead(NetworkHeadComponent):
     """

     def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module:
-        layers = [nn.Flatten()]
+        layers = []
         in_features = np.prod(input_shape).item()
         out_features = np.prod(output_shape).item()
         layers.append(_activations[self.config["activation"]]())
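
Dropping the leading nn.Flatten ties in with the property change below: without it, the head's Linear layer assumes the backbone output is already flat, which holds for tabular data but not for image or time-series tensors. A rough illustration with hypothetical shapes:

    import torch
    import torch.nn as nn

    head = nn.Linear(3 * 64 * 64, 5)
    x = torch.randn(8, 3, 64, 64)  # 4-D image batch
    # head(x) would raise a shape error: Linear acts on the last dimension.
    print(head(x.flatten(start_dim=1)).shape)  # torch.Size([8, 5]) after flattening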
@@ -34,8 +34,8 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[
             'shortname': 'NoHead',
             'name': 'NoHead',
             'handles_tabular': True,
-            'handles_image': True,
-            'handles_time_series': True,
+            'handles_image': False,
+            'handles_time_series': False,
         }

     @staticmethod

test/test_pipeline/components/setup/test_setup.py

Lines changed: 4 additions & 3 deletions
@@ -422,8 +422,7 @@ def test_add_network_backbone(self):
 class TestNetworkHead:
     def test_all_heads_available(self):
         network_head_choice = NetworkHeadChoice(dataset_properties={})
-
-        assert len(network_head_choice.get_components().keys()) == 2
+        assert len(network_head_choice.get_components().keys()) == 3

     @pytest.mark.parametrize('task_type_input_output_shape', [(constants.IMAGE_CLASSIFICATION, (3, 64, 64), (5,)),
                                                               (constants.IMAGE_REGRESSION, (3, 64, 64), (1,)),
@@ -441,7 +440,9 @@ def test_dummy_forward_backward_pass(self, task_type_input_output_shape):
         if task_type in constants.CLASSIFICATION_TASKS:
             dataset_properties["num_classes"] = output_shape[0]

-        cs = network_head_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties)
+        cs = network_head_choice.get_hyperparameter_search_space(
+            dataset_properties=dataset_properties,
+        )
         # test 10 random configurations
         for i in range(10):
             config = cs.sample_configuration()
