Skip to content

Commit 9168bd3

Browse files
committed
Add dropout shape as a hyperparameter (#213)
* Add dropout shape as a hyperparameter * fix stupid bug
1 parent b72d8c6 commit 9168bd3

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> torch.nn.Sequential:
4545
# n_units for the architecture, since, it is mostly implemented for the
4646
# output layer, which is part of the head and not of the backbone.
4747
dropout_shape = get_shaped_neuron_counts(
48-
shape=self.config['resnet_shape'],
49-
in_feat=0,
50-
out_feat=0,
51-
max_neurons=self.config["max_dropout"],
52-
layer_count=self.config['num_groups'] + 1,
53-
)[:-1]
48+
self.config['dropout_shape'], 0, 0, 1000, self.config['num_groups']
49+
)
50+
51+
dropout_shape = [
52+
dropout / 1000 * self.config["max_dropout"] for dropout in dropout_shape
53+
]
5454

5555
self.config.update(
5656
{"dropout_%d" % (i + 1): dropout for i, dropout in enumerate(dropout_shape)}
@@ -136,6 +136,13 @@ def get_hyperparameter_search_space( # type: ignore[override]
136136
max_dropout: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="max_dropout",
137137
value_range=(0, 0.8),
138138
default_value=0.5),
139+
dropout_shape: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="dropout_shape",
140+
value_range=('funnel', 'long_funnel',
141+
'diamond', 'hexagon',
142+
'brick', 'triangle',
143+
'stairs'),
144+
default_value='funnel',
145+
),
139146
max_shake_drop_probability: HyperparameterSearchSpace = HyperparameterSearchSpace(
140147
hyperparameter="max_shake_drop_probability",
141148
value_range=(0, 1),
@@ -165,8 +172,10 @@ def get_hyperparameter_search_space( # type: ignore[override]
165172

166173
if dropout_flag:
167174
max_dropout = get_hyperparameter(max_dropout, UniformFloatHyperparameter)
168-
cs.add_hyperparameter(max_dropout)
175+
dropout_shape = get_hyperparameter(dropout_shape, CategoricalHyperparameter)
176+
cs.add_hyperparameters([dropout_shape, max_dropout])
169177
cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True))
178+
cs.add_condition(CS.EqualsCondition(dropout_shape, use_dropout, True))
170179

171180
skip_connection_flag = False
172181
if any(use_skip_connection.value_range):

0 commit comments

Comments
 (0)