Skip to content

Commit 38cf3c7

Browse files
authored
fix bug in adversarial trainer (#207)
1 parent ad60a0a commit 38cf3c7

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,22 +196,35 @@ def get_hyperparameter_search_space(
196196

197197
add_hyperparameter(cs, epsilon, UniformFloatHyperparameter)
198198
add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
199+
snapshot_ensemble_flag = False
200+
if any(use_snapshot_ensemble.value_range):
201+
snapshot_ensemble_flag = True
202+
199203
use_snapshot_ensemble = get_hyperparameter(use_snapshot_ensemble, CategoricalHyperparameter)
200-
se_lastk = get_hyperparameter(se_lastk, Constant)
201-
cs.add_hyperparameters([use_snapshot_ensemble, se_lastk])
202-
cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
203-
cs.add_condition(cond)
204+
cs.add_hyperparameter(use_snapshot_ensemble)
205+
206+
if snapshot_ensemble_flag:
207+
se_lastk = get_hyperparameter(se_lastk, Constant)
208+
cs.add_hyperparameter(se_lastk)
209+
cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
210+
cs.add_condition(cond)
211+
212+
lookahead_flag = False
213+
if any(use_lookahead_optimizer.value_range):
214+
lookahead_flag = True
204215

205216
use_lookahead_optimizer = get_hyperparameter(use_lookahead_optimizer, CategoricalHyperparameter)
206217
cs.add_hyperparameter(use_lookahead_optimizer)
207-
la_config_space = Lookahead.get_hyperparameter_search_space(la_steps=la_steps,
208-
la_alpha=la_alpha)
209-
parent_hyperparameter = {'parent': use_lookahead_optimizer, 'value': True}
210-
cs.add_configuration_space(
211-
Lookahead.__name__,
212-
la_config_space,
213-
parent_hyperparameter=parent_hyperparameter
214-
)
218+
219+
if lookahead_flag:
220+
la_config_space = Lookahead.get_hyperparameter_search_space(la_steps=la_steps,
221+
la_alpha=la_alpha)
222+
parent_hyperparameter = {'parent': use_lookahead_optimizer, 'value': True}
223+
cs.add_configuration_space(
224+
Lookahead.__name__,
225+
la_config_space,
226+
parent_hyperparameter=parent_hyperparameter
227+
)
215228

216229
"""
217230
if dataset_properties is not None:

0 commit comments

Comments
 (0)