Skip to content

Commit ce89f92

Browse files
[FIX] SWA and SE with non cyclic schedulers (#395)
* Enable learned embeddings, fix bug with non cyclic schedulers * add forbidden condition cyclic lr * refactor base_pipeline forbidden conditions * Apply suggestions from code review Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent 484ead4 commit ce89f92

File tree

6 files changed

+156
-142
lines changed

6 files changed

+156
-142
lines changed

autoPyTorch/pipeline/base_pipeline.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from copy import copy
12
import warnings
23
from abc import ABCMeta
34
from collections import Counter
45
from typing import Any, Dict, List, Optional, Tuple, Union
56

67
from ConfigSpace import Configuration
78
from ConfigSpace.configuration_space import ConfigurationSpace
9+
from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause
810

911
import numpy as np
1012

@@ -295,6 +297,67 @@ def _get_hyperparameter_search_space(self,
295297
"""
296298
raise NotImplementedError()
297299

300+
def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpace:
    """
    Add forbidden conditions to ensure valid configurations.

    Currently, Learned Entity Embedding is only valid when encoder is one hot encoder
    and CyclicLR is disabled when using stochastic weight averaging and snapshot
    ensembling.

    Args:
        cs (ConfigurationSpace):
            Configuration space to which forbidden conditions are added.

    Returns:
        ConfigurationSpace:
            The same configuration space with the forbidden clauses added.
    """
    # Learned Entity Embedding is only valid when encoder is one hot encoder
    if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
        embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
        if 'LearnedEntityEmbedding' in embeddings:
            encoders = cs.get_hyperparameter('encoder:__choice__').choices
            # Fallback defaults to use if adding a clause would forbid the
            # current default configuration.
            possible_default_embeddings = copy(list(embeddings))
            del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]

            for encoder in encoders:
                if encoder == 'OneHotEncoder':
                    continue
                while True:
                    try:
                        cs.add_forbidden_clause(ForbiddenAndConjunction(
                            ForbiddenEqualsClause(cs.get_hyperparameter(
                                'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
                            ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
                        ))
                        break
                    except ValueError:
                        # ConfigSpace raises ValueError when a clause forbids the
                        # default configuration: change the default and try again
                        try:
                            default = possible_default_embeddings.pop()
                        except IndexError:
                            raise ValueError("Cannot find a legal default configuration")
                        cs.get_hyperparameter('network_embedding:__choice__').default_value = default

    # Disable CyclicLR until todo is completed.
    if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
        trainers = cs.get_hyperparameter('trainer:__choice__').choices
        # TODO: update cyclic lr to use n_restarts and adjust according to batch size
        cyclic_lr_name = 'CyclicLR'
        # The scheduler hyperparameter is loop-invariant: fetch it once instead
        # of re-fetching it (and its choices) for every trainer.
        scheduler_hp = cs.get_hyperparameter('lr_scheduler:__choice__')
        if cyclic_lr_name in scheduler_hp.choices:
            for trainer in trainers:
                # disable snapshot ensembles and stochastic weight averaging
                cs.add_forbidden_clause(ForbiddenAndConjunction(
                    ForbiddenEqualsClause(cs.get_hyperparameter(
                        f'trainer:{trainer}:use_snapshot_ensemble'), True),
                    ForbiddenEqualsClause(scheduler_hp, cyclic_lr_name)
                ))
                cs.add_forbidden_clause(ForbiddenAndConjunction(
                    ForbiddenEqualsClause(cs.get_hyperparameter(
                        f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
                    ForbiddenEqualsClause(scheduler_hp, cyclic_lr_name)
                ))
    return cs
360+
298361
def __repr__(self) -> str:
299362
"""Retrieves a str representation of the current pipeline
300363

autoPyTorch/pipeline/components/setup/network_embedding/__init__.py

Lines changed: 48 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -148,71 +148,64 @@ def get_hyperparameter_search_space(
148148
if default is None:
149149
defaults = [
150150
'NoEmbedding',
151-
# 'LearnedEntityEmbedding',
151+
'LearnedEntityEmbedding',
152152
]
153153
for default_ in defaults:
154154
if default_ in available_embedding:
155155
default = default_
156156
break
157157

158-
# Restrict embedding to NoEmbedding until preprocessing is fixed
159-
embedding = CSH.CategoricalHyperparameter('__choice__',
160-
['NoEmbedding'],
161-
default_value=default)
158+
if isinstance(dataset_properties['categorical_columns'], list):
159+
categorical_columns = dataset_properties['categorical_columns']
160+
else:
161+
categorical_columns = []
162+
163+
updates = self._get_search_space_updates()
164+
if '__choice__' in updates.keys():
165+
choice_hyperparameter = updates['__choice__']
166+
if not set(choice_hyperparameter.value_range).issubset(available_embedding):
167+
raise ValueError("Expected given update for {} to have "
168+
"choices in {} got {}".format(self.__class__.__name__,
169+
available_embedding,
170+
choice_hyperparameter.value_range))
171+
if len(categorical_columns) == 0:
172+
assert len(choice_hyperparameter.value_range) == 1
173+
if 'NoEmbedding' not in choice_hyperparameter.value_range:
174+
raise ValueError("Provided {} in choices, however, the dataset "
175+
"is incompatible with it".format(choice_hyperparameter.value_range))
176+
embedding = CSH.CategoricalHyperparameter('__choice__',
177+
choice_hyperparameter.value_range,
178+
default_value=choice_hyperparameter.default_value)
179+
else:
180+
181+
if len(categorical_columns) == 0:
182+
default = 'NoEmbedding'
183+
if include is not None and default not in include:
184+
raise ValueError("Provided {} in include, however, the dataset "
185+
"is incompatible with it".format(include))
186+
embedding = CSH.CategoricalHyperparameter('__choice__',
187+
['NoEmbedding'],
188+
default_value=default)
189+
else:
190+
embedding = CSH.CategoricalHyperparameter('__choice__',
191+
list(available_embedding.keys()),
192+
default_value=default)
193+
162194
cs.add_hyperparameter(embedding)
195+
for name in embedding.choices:
196+
updates = self._get_search_space_updates(prefix=name)
197+
config_space = available_embedding[name].get_hyperparameter_search_space(dataset_properties, # type: ignore
198+
**updates)
199+
parent_hyperparameter = {'parent': embedding, 'value': name}
200+
cs.add_configuration_space(
201+
name,
202+
config_space,
203+
parent_hyperparameter=parent_hyperparameter
204+
)
205+
163206
self.configuration_space_ = cs
164207
self.dataset_properties_ = dataset_properties
165208
return cs
166-
# categorical_columns = dataset_properties['categorical_columns'] \
167-
# if isinstance(dataset_properties['categorical_columns'], List) else []
168-
169-
# updates = self._get_search_space_updates()
170-
# if '__choice__' in updates.keys():
171-
# choice_hyperparameter = updates['__choice__']
172-
# if not set(choice_hyperparameter.value_range).issubset(available_embedding):
173-
# raise ValueError("Expected given update for {} to have "
174-
# "choices in {} got {}".format(self.__class__.__name__,
175-
# available_embedding,
176-
# choice_hyperparameter.value_range))
177-
# if len(categorical_columns) == 0:
178-
# assert len(choice_hyperparameter.value_range) == 1
179-
# if 'NoEmbedding' not in choice_hyperparameter.value_range:
180-
# raise ValueError("Provided {} in choices, however, the dataset "
181-
# "is incompatible with it".format(choice_hyperparameter.value_range))
182-
# embedding = CSH.CategoricalHyperparameter('__choice__',
183-
# choice_hyperparameter.value_range,
184-
# default_value=choice_hyperparameter.default_value)
185-
# else:
186-
187-
# if len(categorical_columns) == 0:
188-
# default = 'NoEmbedding'
189-
# if include is not None and default not in include:
190-
# raise ValueError("Provided {} in include, however, the dataset "
191-
# "is incompatible with it".format(include))
192-
# embedding = CSH.CategoricalHyperparameter('__choice__',
193-
# ['NoEmbedding'],
194-
# default_value=default)
195-
# else:
196-
# embedding = CSH.CategoricalHyperparameter('__choice__',
197-
# list(available_embedding.keys()),
198-
# default_value=default)
199-
200-
# cs.add_hyperparameter(embedding)
201-
# for name in embedding.choices:
202-
# updates = self._get_search_space_updates(prefix=name)
203-
# config_space = available_embedding[name].get_hyperparameter_search_space(
204-
# dataset_properties, # type: ignore
205-
# **updates)
206-
# parent_hyperparameter = {'parent': embedding, 'value': name}
207-
# cs.add_configuration_space(
208-
# name,
209-
# config_space,
210-
# parent_hyperparameter=parent_hyperparameter
211-
# )
212-
213-
# self.configuration_space_ = cs
214-
# self.dataset_properties_ = dataset_properties
215-
# return cs
216209

217210
def transform(self, X: np.ndarray) -> np.ndarray:
218211
assert self.choice is not None, "Cannot call transform before the object is initialized"

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,35 @@ def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
334334
"""
335335
pass
336336

337+
def _swa_update(self) -> None:
338+
"""
339+
perform swa model update
340+
"""
341+
if self.swa_model is None:
342+
raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
343+
self.swa_model.update_parameters(self.model)
344+
self.swa_updated = True
345+
346+
def _se_update(self, epoch: int) -> None:
347+
"""
348+
Add latest model or swa_model to model snapshot ensemble
349+
Args:
350+
epoch (int):
351+
current epoch
352+
"""
353+
if self.model_snapshots is None:
354+
raise ValueError("model snapshots cannot be None when snapshot ensembling is enabled")
355+
is_last_epoch = (epoch == self.budget_tracker.max_epochs)
356+
if is_last_epoch and self.use_stochastic_weight_averaging:
357+
model_copy = deepcopy(self.swa_model)
358+
else:
359+
model_copy = deepcopy(self.model)
360+
361+
assert model_copy is not None
362+
model_copy.cpu()
363+
self.model_snapshots.append(model_copy)
364+
self.model_snapshots = self.model_snapshots[-self.se_lastk:]
365+
337366
def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
338367
"""
339368
Optional place holder for AutoPytorch Extensions.
@@ -344,39 +373,19 @@ def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
344373
if X['is_cyclic_scheduler']:
345374
if hasattr(self.scheduler, 'T_cur') and self.scheduler.T_cur == 0 and epoch != 1:
346375
if self.use_stochastic_weight_averaging:
347-
assert self.swa_model is not None, "SWA model can't be none when" \
348-
" stochastic weight averaging is enabled"
349-
self.swa_model.update_parameters(self.model)
350-
self.swa_updated = True
376+
self._swa_update()
351377
if self.use_snapshot_ensemble:
352-
assert self.model_snapshots is not None, "model snapshots container can't be " \
353-
"none when snapshot ensembling is enabled"
354-
is_last_epoch = (epoch == self.budget_tracker.max_epochs)
355-
if is_last_epoch and self.use_stochastic_weight_averaging:
356-
model_copy = deepcopy(self.swa_model)
357-
else:
358-
model_copy = deepcopy(self.model)
359-
360-
assert model_copy is not None
361-
model_copy.cpu()
362-
self.model_snapshots.append(model_copy)
363-
self.model_snapshots = self.model_snapshots[-self.se_lastk:]
378+
self._se_update(epoch=epoch)
364379
else:
365-
if epoch > self._budget_threshold:
366-
if self.use_stochastic_weight_averaging:
367-
assert self.swa_model is not None, "SWA model can't be none when" \
368-
" stochastic weight averaging is enabled"
369-
self.swa_model.update_parameters(self.model)
370-
self.swa_updated = True
371-
if self.use_snapshot_ensemble:
372-
assert self.model_snapshots is not None, "model snapshots container can't be " \
373-
"none when snapshot ensembling is enabled"
374-
model_copy = deepcopy(self.swa_model) if self.use_stochastic_weight_averaging \
375-
else deepcopy(self.model)
376-
assert model_copy is not None
377-
model_copy.cpu()
378-
self.model_snapshots.append(model_copy)
379-
self.model_snapshots = self.model_snapshots[-self.se_lastk:]
380+
if epoch > self._budget_threshold and self.use_stochastic_weight_averaging:
381+
self._swa_update()
382+
383+
if (
384+
self.use_snapshot_ensemble
385+
and self.budget_tracker.max_epochs is not None
386+
and epoch > (self.budget_tracker.max_epochs - self.se_lastk)
387+
):
388+
self._se_update(epoch=epoch)
380389
return False
381390

382391
def _scheduler_step(

autoPyTorch/pipeline/image_classification.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def _get_hyperparameter_search_space(self,
156156

157157
# Here we add custom code, like this with this
158158
# is not a valid configuration
159+
cs = self._add_forbidden_conditions(cs)
159160

160161
self.configuration_space = cs
161162
self.dataset_properties = dataset_properties

autoPyTorch/pipeline/tabular_classification.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Any, Dict, List, Optional, Tuple, Union
44

55
from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
6-
from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause
76

87
import numpy as np
98

@@ -261,33 +260,9 @@ def _get_hyperparameter_search_space(self,
261260
cs=cs, dataset_properties=dataset_properties,
262261
exclude=exclude, include=include, pipeline=self.steps)
263262

264-
# Here we add custom code, that is used to ensure valid configurations, For example
265-
# Learned Entity Embedding is only valid when encoder is one hot encoder
266-
if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
267-
embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
268-
if 'LearnedEntityEmbedding' in embeddings:
269-
encoders = cs.get_hyperparameter('encoder:__choice__').choices
270-
possible_default_embeddings = copy.copy(list(embeddings))
271-
del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]
272-
273-
for encoder in encoders:
274-
if encoder == 'OneHotEncoder':
275-
continue
276-
while True:
277-
try:
278-
cs.add_forbidden_clause(ForbiddenAndConjunction(
279-
ForbiddenEqualsClause(cs.get_hyperparameter(
280-
'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
281-
ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
282-
))
283-
break
284-
except ValueError:
285-
# change the default and try again
286-
try:
287-
default = possible_default_embeddings.pop()
288-
except IndexError:
289-
raise ValueError("Cannot find a legal default configuration")
290-
cs.get_hyperparameter('network_embedding:__choice__').default_value = default
263+
# Here we add custom code, like this with this
264+
# is not a valid configuration
265+
cs = self._add_forbidden_conditions(cs)
291266

292267
self.configuration_space = cs
293268
self.dataset_properties = dataset_properties

0 commit comments

Comments
 (0)