Skip to content

Commit 989f3ac

Browse files
committed
add forbidden condition cyclic lr
1 parent 8bf6280 commit 989f3ac

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

autoPyTorch/pipeline/tabular_classification.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,29 @@ def _get_hyperparameter_search_space(self,
289289
raise ValueError("Cannot find a legal default configuration")
290290
cs.get_hyperparameter('network_embedding:__choice__').default_value = default
291291

292+
# Disable CyclicLR until todo is completed.
293+
if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
294+
trainers = cs.get_hyperparameter('trainer:__choice__').choices
295+
for trainer in trainers:
296+
available_schedulers = self.named_steps['lr_scheduler'].get_available_components(
297+
dataset_properties=dataset_properties,
298+
exclude=exclude if bool(exclude) else None,
299+
include=include if bool(include) else None)
300+
# TODO: update cyclic lr to use n_restarts and adjust according to batch size
301+
cyclic_lr_name = 'CyclicLR'
302+
if cyclic_lr_name in available_schedulers:
303+
# disable snapshot ensembles and stochastic weight averaging
304+
cs.add_forbidden_clause(ForbiddenAndConjunction(
305+
ForbiddenEqualsClause(cs.get_hyperparameter(
306+
f'trainer:{trainer}:use_snapshot_ensemble'), True),
307+
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
308+
))
309+
cs.add_forbidden_clause(ForbiddenAndConjunction(
310+
ForbiddenEqualsClause(cs.get_hyperparameter(
311+
f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
312+
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
313+
))
314+
292315
self.configuration_space = cs
293316
self.dataset_properties = dataset_properties
294317
return cs

autoPyTorch/pipeline/tabular_regression.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,29 @@ def _get_hyperparameter_search_space(self,
238238
raise ValueError("Cannot find a legal default configuration")
239239
cs.get_hyperparameter('network_embedding:__choice__').default_value = default
240240

241+
# Disable CyclicLR until todo is completed.
242+
if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
243+
trainers = cs.get_hyperparameter('trainer:__choice__').choices
244+
for trainer in trainers:
245+
available_schedulers = self.named_steps['lr_scheduler'].get_available_components(
246+
dataset_properties=dataset_properties,
247+
exclude=exclude if bool(exclude) else None,
248+
include=include if bool(include) else None)
249+
# TODO: update cyclic lr to use n_restarts and adjust according to batch size
250+
cyclic_lr_name = 'CyclicLR'
251+
if cyclic_lr_name in available_schedulers:
252+
# disable snapshot ensembles and stochastic weight averaging
253+
cs.add_forbidden_clause(ForbiddenAndConjunction(
254+
ForbiddenEqualsClause(cs.get_hyperparameter(
255+
f'trainer:{trainer}:use_snapshot_ensemble'), True),
256+
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
257+
))
258+
cs.add_forbidden_clause(ForbiddenAndConjunction(
259+
ForbiddenEqualsClause(cs.get_hyperparameter(
260+
f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
261+
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
262+
))
263+
241264
self.configuration_space = cs
242265
self.dataset_properties = dataset_properties
243266
return cs

0 commit comments

Comments
 (0)