Commit fc5f2d6

ArlindKadra authored and ravinkohli committed

Cocktail hotfixes (#245)

* Fixes for the development branch and regularization cocktails
* Update implementation
* Fix unit tests temporarily
* Implementation update and bug fixes
* Removing unnecessary code
* Addressing Ravin's comments

1 parent a41a2a3 commit fc5f2d6

File tree: 4 files changed, +33, -19 lines

autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py

Lines changed: 13 additions & 2 deletions

@@ -64,7 +64,8 @@ def _add_group(self, in_features: int, out_features: int,
             out_features (int): output dimensionality for the current block
             blocks_per_group (int): Number of ResNet per group
             last_block_index (int): block index for shake regularization
-            dropout (bool): whether or not use dropout
+            dropout (None, float): dropout value for the group. If none,
+                no dropout is applied.
         """
         blocks = list()
         for i in range(blocks_per_group):
@@ -290,13 +291,22 @@ def _build_block(self, in_features: int, out_features: int) -> nn.Module:
             if self.config['use_batch_norm']:
                 layers.append(nn.BatchNorm1d(in_features))
             layers.append(self.activation())
+        else:
+            # if start norm is not None and skip connection is None
+            # we will never apply the start_norm for the first layer in the block,
+            # which is why we should account for this case.
+            if not self.config['use_skip_connection']:
+                if self.config['use_batch_norm']:
+                    layers.append(nn.BatchNorm1d(in_features))
+                layers.append(self.activation())
+
         layers.append(nn.Linear(in_features, out_features))
 
         if self.config['use_batch_norm']:
             layers.append(nn.BatchNorm1d(out_features))
         layers.append(self.activation())
 
-        if self.config["use_dropout"]:
+        if self.dropout is not None:
             layers.append(nn.Dropout(self.dropout))
         layers.append(nn.Linear(out_features, out_features))
 
@@ -321,6 +331,7 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         if self.config["use_skip_connection"]:
             residual = self.shortcut(x)
 
+        # TODO make the below code better
         if self.config["use_skip_connection"]:
             if self.config["multi_branch_choice"] == 'shake-shake':
                 x1 = self.layers(x)
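
The substantive change in this file is that a block's dropout is now gated by the per-group value itself (self.dropout, an optional float) rather than the boolean use_dropout flag. A minimal sketch of that contract, using a simplified stand-in builder rather than the actual ResNetBackbone API:

from typing import Optional

import torch.nn as nn


def build_block(in_features: int, out_features: int,
                dropout: Optional[float]) -> nn.Sequential:
    """Illustrative stand-in for ResNetBackbone._build_block."""
    layers = [nn.Linear(in_features, out_features), nn.ReLU()]
    # After the hotfix, the gate is the per-group dropout value itself;
    # None disables dropout instead of a separate `use_dropout` boolean.
    if dropout is not None:
        layers.append(nn.Dropout(dropout))
    layers.append(nn.Linear(out_features, out_features))
    return nn.Sequential(*layers)


block_with = build_block(64, 64, dropout=0.3)      # contains nn.Dropout(0.3)
block_without = build_block(64, 64, dropout=None)  # no dropout layer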

autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py

Lines changed: 13 additions & 11 deletions

@@ -31,11 +31,13 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> torch.nn.Sequential:
         out_features = self.config["output_dim"]
 
         # use the get_shaped_neuron_counts to update the number of units
-        neuron_counts = get_shaped_neuron_counts(self.config['resnet_shape'],
-                                                 in_features,
-                                                 out_features,
-                                                 self.config['max_units'],
-                                                 self.config['num_groups'] + 2)[:-1]
+        neuron_counts = get_shaped_neuron_counts(
+            self.config['resnet_shape'],
+            in_features,
+            out_features,
+            self.config['max_units'],
+            self.config['num_groups'] + 2,
+        )[:-1]
         self.config.update(
             {"num_units_%d" % (i): num for i, num in enumerate(neuron_counts)}
         )
@@ -45,12 +47,12 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> torch.nn.Sequential:
         # n_units for the architecture, since, it is mostly implemented for the
         # output layer, which is part of the head and not of the backbone.
         dropout_shape = get_shaped_neuron_counts(
-            self.config['dropout_shape'], 0, 0, 1000, self.config['num_groups']
-        )
-
-        dropout_shape = [
-            dropout / 1000 * self.config["max_dropout"] for dropout in dropout_shape
-        ]
+            self.config['dropout_shape'],
+            0,
+            0,
+            self.config["max_dropout"],
+            self.config['num_groups'] + 1,
+        )[:-1]
 
         self.config.update(
             {"dropout_%d" % (i + 1): dropout for i, dropout in enumerate(dropout_shape)}

autoPyTorch/pipeline/components/setup/network_head/no_head.py

Lines changed: 3 additions & 3 deletions

@@ -20,7 +20,7 @@ class NoHead(NetworkHeadComponent):
     """
 
     def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module:
-        layers = [nn.Flatten()]
+        layers = []
         in_features = np.prod(input_shape).item()
         out_features = np.prod(output_shape).item()
         layers.append(_activations[self.config["activation"]]())
@@ -34,8 +34,8 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[
             'shortname': 'NoHead',
             'name': 'NoHead',
             'handles_tabular': True,
-            'handles_image': True,
-            'handles_time_series': True,
+            'handles_image': False,
+            'handles_time_series': False,
         }
 
     @staticmethod
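
With nn.Flatten() removed, the head assumes its input is already flat, which is consistent with narrowing get_properties to tabular data only. A sketch of the resulting module, assuming the lines not shown in the hunk finish with a single Linear projection; the activation lookup is replaced by a fixed nn.ReLU here:

import numpy as np
import torch
import torch.nn as nn


def build_no_head(input_shape, output_shape) -> nn.Module:
    layers = []  # previously started as [nn.Flatten()]
    in_features = np.prod(input_shape).item()
    out_features = np.prod(output_shape).item()
    layers.append(nn.ReLU())  # stand-in for _activations[config["activation"]]()
    layers.append(nn.Linear(in_features, out_features))  # assumed final projection
    return nn.Sequential(*layers)


head = build_no_head((10,), (2,))
out = head(torch.randn(8, 10))  # works only for already-flat (tabular) inputs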

test/test_pipeline/components/setup/test_setup.py

Lines changed: 4 additions & 3 deletions

@@ -501,8 +501,7 @@ def test_dropout(self, resnet_shape):
 class TestNetworkHead:
     def test_all_heads_available(self):
         network_head_choice = NetworkHeadChoice(dataset_properties={})
-
-        assert len(network_head_choice.get_components().keys()) == 2
+        assert len(network_head_choice.get_components().keys()) == 3
 
     @pytest.mark.parametrize('task_type_input_output_shape', [(constants.IMAGE_CLASSIFICATION, (3, 64, 64), (5,)),
                                                               (constants.IMAGE_REGRESSION, (3, 64, 64), (1,)),
@@ -520,7 +519,9 @@ def test_dummy_forward_backward_pass(self, task_type_input_output_shape):
         if task_type in constants.CLASSIFICATION_TASKS:
             dataset_properties["num_classes"] = output_shape[0]
 
-        cs = network_head_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties)
+        cs = network_head_choice.get_hyperparameter_search_space(
+            dataset_properties=dataset_properties,
+        )
         # test 10 random configurations
         for _ in range(10):
             config = cs.sample_configuration()
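
The updated test still follows the usual ConfigSpace pattern: build a search space, then sample and evaluate random configurations. A minimal, self-contained sketch of that pattern; the max_dropout hyperparameter is illustrative, not one of NetworkHeadChoice's actual parameters:

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import UniformFloatHyperparameter

cs = ConfigurationSpace()
cs.add_hyperparameter(UniformFloatHyperparameter('max_dropout', 0.0, 0.8))

# Test a handful of random configurations, as the unit test does with 10.
for _ in range(10):
    config = cs.sample_configuration()
    assert 0.0 <= config['max_dropout'] <= 0.8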
