@@ -289,6 +289,29 @@ def _get_hyperparameter_search_space(self,
289
289
raise ValueError ("Cannot find a legal default configuration" )
290
290
cs .get_hyperparameter ('network_embedding:__choice__' ).default_value = default
291
291
292
+ # Disable CyclicLR until todo is completed.
293
+ if 'lr_scheduler' in self .named_steps .keys () and 'trainer' in self .named_steps .keys ():
294
+ trainers = cs .get_hyperparameter ('trainer:__choice__' ).choices
295
+ for trainer in trainers :
296
+ available_schedulers = self .named_steps ['lr_scheduler' ].get_available_components (
297
+ dataset_properties = dataset_properties ,
298
+ exclude = exclude if bool (exclude ) else None ,
299
+ include = include if bool (include ) else None )
300
+ # TODO: update cyclic lr to use n_restarts and adjust according to batch size
301
+ cyclic_lr_name = 'CyclicLR'
302
+ if cyclic_lr_name in available_schedulers :
303
+ # disable snapshot ensembles and stochastic weight averaging
304
+ cs .add_forbidden_clause (ForbiddenAndConjunction (
305
+ ForbiddenEqualsClause (cs .get_hyperparameter (
306
+ f'trainer:{ trainer } :use_snapshot_ensemble' ), True ),
307
+ ForbiddenEqualsClause (cs .get_hyperparameter ('lr_scheduler:__choice__' ), cyclic_lr_name )
308
+ ))
309
+ cs .add_forbidden_clause (ForbiddenAndConjunction (
310
+ ForbiddenEqualsClause (cs .get_hyperparameter (
311
+ f'trainer:{ trainer } :use_stochastic_weight_averaging' ), True ),
312
+ ForbiddenEqualsClause (cs .get_hyperparameter ('lr_scheduler:__choice__' ), cyclic_lr_name )
313
+ ))
314
+
292
315
self .configuration_space = cs
293
316
self .dataset_properties = dataset_properties
294
317
return cs
0 commit comments