@@ -45,12 +45,12 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> torch.nn.Sequential:
         # n_units for the architecture, since, it is mostly implemented for the
         # output layer, which is part of the head and not of the backbone.
         dropout_shape = get_shaped_neuron_counts(
-            shape=self.config['resnet_shape'],
-            in_feat=0,
-            out_feat=0,
-            max_neurons=self.config["max_dropout"],
-            layer_count=self.config['num_groups'] + 1,
-        )[:-1]
+            self.config['dropout_shape'], 0, 0, 1000, self.config['num_groups']
+        )
+
+        dropout_shape = [
+            dropout / 1000 * self.config["max_dropout"] for dropout in dropout_shape
+        ]
 
         self.config.update(
             {"dropout_%d" % (i + 1): dropout for i, dropout in enumerate(dropout_shape)}
@@ -136,6 +136,13 @@ def get_hyperparameter_search_space(  # type: ignore[override]
         max_dropout: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="max_dropout",
                                                                            value_range=(0, 0.8),
                                                                            default_value=0.5),
+        dropout_shape: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="dropout_shape",
+                                                                             value_range=('funnel', 'long_funnel',
+                                                                                          'diamond', 'hexagon',
+                                                                                          'brick', 'triangle',
+                                                                                          'stairs'),
+                                                                             default_value='funnel',
+                                                                             ),
         max_shake_drop_probability: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="max_shake_drop_probability",
             value_range=(0, 1),
@@ -165,8 +172,10 @@ def get_hyperparameter_search_space(  # type: ignore[override]
 
         if dropout_flag:
             max_dropout = get_hyperparameter(max_dropout, UniformFloatHyperparameter)
-            cs.add_hyperparameter(max_dropout)
+            dropout_shape = get_hyperparameter(dropout_shape, CategoricalHyperparameter)
+            cs.add_hyperparameters([dropout_shape, max_dropout])
             cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True))
+            cs.add_condition(CS.EqualsCondition(dropout_shape, use_dropout, True))
 
         skip_connection_flag = False
         if any(use_skip_connection.value_range):
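Taken together, the two search-space hunks make dropout_shape a categorical hyperparameter that, like max_dropout, is only active when use_dropout is True. A self-contained sketch of the resulting conditional structure built directly with ConfigSpace (this bypasses AutoPyTorch's HyperparameterSearchSpace and get_hyperparameter wrappers; the use_dropout default below is illustrative):

    import ConfigSpace as CS
    from ConfigSpace.hyperparameters import (
        CategoricalHyperparameter,
        UniformFloatHyperparameter,
    )

    cs = CS.ConfigurationSpace()

    # Parent switch: the dropout children are only sampled when it is True.
    use_dropout = CategoricalHyperparameter("use_dropout", choices=[True, False],
                                            default_value=False)

    # Children corresponding to the diff: the existing max_dropout range and
    # the newly added categorical dropout_shape.
    max_dropout = UniformFloatHyperparameter("max_dropout", lower=0.0, upper=0.8,
                                             default_value=0.5)
    dropout_shape = CategoricalHyperparameter(
        "dropout_shape",
        choices=["funnel", "long_funnel", "diamond", "hexagon",
                 "brick", "triangle", "stairs"],
        default_value="funnel",
    )

    cs.add_hyperparameters([use_dropout, dropout_shape, max_dropout])
    cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True))
    cs.add_condition(CS.EqualsCondition(dropout_shape, use_dropout, True))

    print(cs.sample_configuration())

Sampling from this space shows the conditional behaviour the diff introduces: when use_dropout comes out False, neither max_dropout nor dropout_shape appears in the sampled configuration.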