
Commit 366bede

[FIX] SWA and SE with non cyclic schedulers (#395)
* Enable learned embeddings, fix bug with non cyclic schedulers
* Add forbidden condition for cyclic lr
* Refactor base_pipeline forbidden conditions
* Apply suggestions from code review

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent c1fffa1 commit 366bede

File tree

6 files changed: +156, -142 lines

autoPyTorch/pipeline/base_pipeline.py

Lines changed: 63 additions & 0 deletions
@@ -1,10 +1,12 @@
+from copy import copy
 import warnings
 from abc import ABCMeta
 from collections import Counter
 from typing import Any, Dict, List, Optional, Tuple, Union

 from ConfigSpace import Configuration
 from ConfigSpace.configuration_space import ConfigurationSpace
+from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause

 import numpy as np

@@ -295,6 +297,67 @@ def _get_hyperparameter_search_space(self,
         """
         raise NotImplementedError()

+    def _add_forbidden_conditions(self, cs):
+        """
+        Add forbidden conditions to ensure valid configurations.
+        Currently, Learned Entity Embedding is only valid when the encoder is a
+        one-hot encoder, and CyclicLR is disabled when using stochastic weight
+        averaging or snapshot ensembling.
+
+        Args:
+            cs (ConfigurationSpace):
+                Configuration space to which forbidden conditions are added.
+
+        """
+
+        # Learned Entity Embedding is only valid when encoder is one hot encoder
+        if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
+            embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
+            if 'LearnedEntityEmbedding' in embeddings:
+                encoders = cs.get_hyperparameter('encoder:__choice__').choices
+                possible_default_embeddings = copy(list(embeddings))
+                del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]
+
+                for encoder in encoders:
+                    if encoder == 'OneHotEncoder':
+                        continue
+                    while True:
+                        try:
+                            cs.add_forbidden_clause(ForbiddenAndConjunction(
+                                ForbiddenEqualsClause(cs.get_hyperparameter(
+                                    'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
+                                ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
+                            ))
+                            break
+                        except ValueError:
+                            # change the default and try again
+                            try:
+                                default = possible_default_embeddings.pop()
+                            except IndexError:
+                                raise ValueError("Cannot find a legal default configuration")
+                            cs.get_hyperparameter('network_embedding:__choice__').default_value = default
+
+        # Disable CyclicLR with snapshot ensembling and SWA until the TODO below is completed.
+        if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
+            trainers = cs.get_hyperparameter('trainer:__choice__').choices
+            for trainer in trainers:
+                available_schedulers = cs.get_hyperparameter('lr_scheduler:__choice__').choices
+                # TODO: update cyclic lr to use n_restarts and adjust according to batch size
+                cyclic_lr_name = 'CyclicLR'
+                if cyclic_lr_name in available_schedulers:
+                    # disable snapshot ensembles and stochastic weight averaging
+                    cs.add_forbidden_clause(ForbiddenAndConjunction(
+                        ForbiddenEqualsClause(cs.get_hyperparameter(
+                            f'trainer:{trainer}:use_snapshot_ensemble'), True),
+                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
+                    ))
+                    cs.add_forbidden_clause(ForbiddenAndConjunction(
+                        ForbiddenEqualsClause(cs.get_hyperparameter(
+                            f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
+                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
+                    ))
+        return cs
+
     def __repr__(self) -> str:
         """Retrieves a str representation of the current pipeline

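A note on the retry loop above: ConfigSpace's add_forbidden_clause raises a ValueError if the new clause would forbid the space's current default configuration, which is why _add_forbidden_conditions moves the default embedding and tries again. A minimal standalone sketch of that pattern (toy hyperparameter names, not autoPyTorch's actual pipeline space):

    from ConfigSpace.configuration_space import ConfigurationSpace
    from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause
    from ConfigSpace.hyperparameters import CategoricalHyperparameter

    cs = ConfigurationSpace()
    embedding = CategoricalHyperparameter(
        'embedding', ['NoEmbedding', 'LearnedEntityEmbedding'],
        default_value='LearnedEntityEmbedding')
    encoder = CategoricalHyperparameter(
        'encoder', ['OneHotEncoder', 'OrdinalEncoder'],
        default_value='OrdinalEncoder')
    cs.add_hyperparameters([embedding, encoder])

    # LearnedEntityEmbedding may only be combined with OneHotEncoder
    clause = ForbiddenAndConjunction(
        ForbiddenEqualsClause(embedding, 'LearnedEntityEmbedding'),
        ForbiddenEqualsClause(encoder, 'OrdinalEncoder'),
    )
    try:
        cs.add_forbidden_clause(clause)  # forbids the current default -> ValueError
    except ValueError:
        embedding.default_value = 'NoEmbedding'  # move the default, then retry
        cs.add_forbidden_clause(clause)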
autoPyTorch/pipeline/components/setup/network_embedding/__init__.py

Lines changed: 48 additions & 55 deletions
@@ -146,71 +146,64 @@ def get_hyperparameter_search_space(
         if default is None:
             defaults = [
                 'NoEmbedding',
-                # 'LearnedEntityEmbedding',
+                'LearnedEntityEmbedding',
             ]
             for default_ in defaults:
                 if default_ in available_embedding:
                     default = default_
                     break

-        # Restrict embedding to NoEmbedding until preprocessing is fixed
-        embedding = CSH.CategoricalHyperparameter('__choice__',
-                                                  ['NoEmbedding'],
-                                                  default_value=default)
+        if isinstance(dataset_properties['categorical_columns'], list):
+            categorical_columns = dataset_properties['categorical_columns']
+        else:
+            categorical_columns = []
+
+        updates = self._get_search_space_updates()
+        if '__choice__' in updates.keys():
+            choice_hyperparameter = updates['__choice__']
+            if not set(choice_hyperparameter.value_range).issubset(available_embedding):
+                raise ValueError("Expected given update for {} to have "
+                                 "choices in {} got {}".format(self.__class__.__name__,
+                                                               available_embedding,
+                                                               choice_hyperparameter.value_range))
+            if len(categorical_columns) == 0:
+                assert len(choice_hyperparameter.value_range) == 1
+                if 'NoEmbedding' not in choice_hyperparameter.value_range:
+                    raise ValueError("Provided {} in choices, however, the dataset "
+                                     "is incompatible with it".format(choice_hyperparameter.value_range))
+            embedding = CSH.CategoricalHyperparameter('__choice__',
+                                                      choice_hyperparameter.value_range,
+                                                      default_value=choice_hyperparameter.default_value)
+        else:
+            if len(categorical_columns) == 0:
+                default = 'NoEmbedding'
+                if include is not None and default not in include:
+                    raise ValueError("Provided {} in include, however, the dataset "
+                                     "is incompatible with it".format(include))
+                embedding = CSH.CategoricalHyperparameter('__choice__',
+                                                          ['NoEmbedding'],
+                                                          default_value=default)
+            else:
+                embedding = CSH.CategoricalHyperparameter('__choice__',
+                                                          list(available_embedding.keys()),
+                                                          default_value=default)

         cs.add_hyperparameter(embedding)
+        for name in embedding.choices:
+            updates = self._get_search_space_updates(prefix=name)
+            config_space = available_embedding[name].get_hyperparameter_search_space(dataset_properties,  # type: ignore
+                                                                                     **updates)
+            parent_hyperparameter = {'parent': embedding, 'value': name}
+            cs.add_configuration_space(
+                name,
+                config_space,
+                parent_hyperparameter=parent_hyperparameter
+            )

         self.configuration_space_ = cs
         self.dataset_properties_ = dataset_properties
         return cs
-        # (removed: ~50 commented-out lines duplicating the search-space
-        #  construction that is re-enabled above)

     def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
         assert self.choice is not None, "Cannot call transform before the object is initialized"

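For context, the loop re-enabled above relies on ConfigSpace's add_configuration_space: each embedding choice contributes its own sub-space, whose hyperparameter names are prefixed with the choice name and made conditional on the parent __choice__ value. A minimal sketch, with dimension_reduction as a made-up illustrative sub-hyperparameter:

    import ConfigSpace.hyperparameters as CSH
    from ConfigSpace.configuration_space import ConfigurationSpace

    cs = ConfigurationSpace()
    embedding = CSH.CategoricalHyperparameter(
        '__choice__', ['NoEmbedding', 'LearnedEntityEmbedding'],
        default_value='NoEmbedding')
    cs.add_hyperparameter(embedding)

    # Sub-space for one choice; its names are prefixed and tied to the
    # parent value with an equality condition.
    sub_space = ConfigurationSpace()
    sub_space.add_hyperparameter(CSH.UniformFloatHyperparameter(
        'dimension_reduction', lower=0.1, upper=1.0, default_value=0.5))
    cs.add_configuration_space(
        'LearnedEntityEmbedding', sub_space,
        parent_hyperparameter={'parent': embedding, 'value': 'LearnedEntityEmbedding'})

    # 'LearnedEntityEmbedding:dimension_reduction' is only active when
    # __choice__ == 'LearnedEntityEmbedding'
    print(cs.sample_configuration())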
autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 40 additions & 31 deletions
@@ -342,6 +342,35 @@ def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
         """
         pass

+    def _swa_update(self) -> None:
+        """
+        Perform a SWA model update.
+        """
+        if self.swa_model is None:
+            raise ValueError("SWA model cannot be None when stochastic weight averaging is enabled")
+        self.swa_model.update_parameters(self.model)
+        self.swa_updated = True
+
+    def _se_update(self, epoch: int) -> None:
+        """
+        Add the latest model (or the swa_model) to the snapshot ensemble.
+
+        Args:
+            epoch (int):
+                current epoch
+        """
+        if self.model_snapshots is None:
+            raise ValueError("Model snapshots cannot be None when snapshot ensembling is enabled")
+        is_last_epoch = (epoch == self.budget_tracker.max_epochs)
+        if is_last_epoch and self.use_stochastic_weight_averaging:
+            model_copy = deepcopy(self.swa_model)
+        else:
+            model_copy = deepcopy(self.model)
+
+        assert model_copy is not None
+        model_copy.cpu()
+        self.model_snapshots.append(model_copy)
+        self.model_snapshots = self.model_snapshots[-self.se_lastk:]
+
     def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
         """
         Optional placeholder for AutoPyTorch extensions.
@@ -352,39 +381,19 @@ def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
         if X['is_cyclic_scheduler']:
             if hasattr(self.scheduler, 'T_cur') and self.scheduler.T_cur == 0 and epoch != 1:
                 if self.use_stochastic_weight_averaging:
-                    assert self.swa_model is not None, "SWA model can't be none when" \
-                                                       " stochastic weight averaging is enabled"
-                    self.swa_model.update_parameters(self.model)
-                    self.swa_updated = True
+                    self._swa_update()
                 if self.use_snapshot_ensemble:
-                    assert self.model_snapshots is not None, "model snapshots container can't be " \
-                                                             "none when snapshot ensembling is enabled"
-                    is_last_epoch = (epoch == self.budget_tracker.max_epochs)
-                    if is_last_epoch and self.use_stochastic_weight_averaging:
-                        model_copy = deepcopy(self.swa_model)
-                    else:
-                        model_copy = deepcopy(self.model)
-
-                    assert model_copy is not None
-                    model_copy.cpu()
-                    self.model_snapshots.append(model_copy)
-                    self.model_snapshots = self.model_snapshots[-self.se_lastk:]
+                    self._se_update(epoch=epoch)
         else:
-            if epoch > self._budget_threshold:
-                if self.use_stochastic_weight_averaging:
-                    assert self.swa_model is not None, "SWA model can't be none when" \
-                                                       " stochastic weight averaging is enabled"
-                    self.swa_model.update_parameters(self.model)
-                    self.swa_updated = True
-                if self.use_snapshot_ensemble:
-                    assert self.model_snapshots is not None, "model snapshots container can't be " \
-                                                             "none when snapshot ensembling is enabled"
-                    model_copy = deepcopy(self.swa_model) if self.use_stochastic_weight_averaging \
-                        else deepcopy(self.model)
-                    assert model_copy is not None
-                    model_copy.cpu()
-                    self.model_snapshots.append(model_copy)
-                    self.model_snapshots = self.model_snapshots[-self.se_lastk:]
+            if epoch > self._budget_threshold and self.use_stochastic_weight_averaging:
+                self._swa_update()
+
+            if (
+                self.use_snapshot_ensemble
+                and self.budget_tracker.max_epochs is not None
+                and epoch > (self.budget_tracker.max_epochs - self.se_lastk)
+            ):
+                self._se_update(epoch=epoch)
         return False

     def _scheduler_step(

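For readers unfamiliar with the two techniques being refactored here: update_parameters folds the current weights into a running average (stochastic weight averaging), and the snapshot ensemble keeps copies of the last se_lastk models. A minimal standalone sketch of both patterns using torch.optim.swa_utils (toy model and thresholds, not autoPyTorch's trainer):

    from copy import deepcopy

    import torch
    from torch.optim.swa_utils import AveragedModel

    model = torch.nn.Linear(10, 2)
    swa_model = AveragedModel(model)   # running average of the weights
    model_snapshots = []               # snapshot ensemble container
    se_lastk = 3
    max_epochs = 20
    budget_threshold = int(0.75 * max_epochs)

    for epoch in range(1, max_epochs + 1):
        # ... one training epoch updating `model` in place ...

        # SWA: past the budget threshold, average in the current weights
        # (the call that _swa_update wraps).
        if epoch > budget_threshold:
            swa_model.update_parameters(model)

        # SE: in the last se_lastk epochs, snapshot the model (the SWA
        # model on the final epoch) and keep only the newest copies,
        # mirroring _se_update.
        if epoch > max_epochs - se_lastk:
            snapshot = deepcopy(swa_model if epoch == max_epochs else model)
            snapshot.cpu()
            model_snapshots.append(snapshot)
            model_snapshots = model_snapshots[-se_lastk:]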
autoPyTorch/pipeline/image_classification.py

Lines changed: 1 addition & 0 deletions
@@ -156,6 +156,7 @@ def _get_hyperparameter_search_space(self,

         # Here we add custom code to exclude combinations that
         # are not valid configurations
+        cs = self._add_forbidden_conditions(cs)

         self.configuration_space = cs
         self.dataset_properties = dataset_properties

autoPyTorch/pipeline/tabular_classification.py

Lines changed: 3 additions & 28 deletions
@@ -3,7 +3,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
-from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause

 import numpy as np

@@ -261,33 +260,9 @@ def _get_hyperparameter_search_space(self,
             cs=cs, dataset_properties=dataset_properties,
             exclude=exclude, include=include, pipeline=self.steps)

-        # Here we add custom code, that is used to ensure valid configurations, For example
-        # Learned Entity Embedding is only valid when encoder is one hot encoder
-        if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
-            embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
-            if 'LearnedEntityEmbedding' in embeddings:
-                encoders = cs.get_hyperparameter('encoder:__choice__').choices
-                possible_default_embeddings = copy.copy(list(embeddings))
-                del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]
-
-                for encoder in encoders:
-                    if encoder == 'OneHotEncoder':
-                        continue
-                    while True:
-                        try:
-                            cs.add_forbidden_clause(ForbiddenAndConjunction(
-                                ForbiddenEqualsClause(cs.get_hyperparameter(
-                                    'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
-                                ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
-                            ))
-                            break
-                        except ValueError:
-                            # change the default and try again
-                            try:
-                                default = possible_default_embeddings.pop()
-                            except IndexError:
-                                raise ValueError("Cannot find a legal default configuration")
-                            cs.get_hyperparameter('network_embedding:__choice__').default_value = default
+        # Here we add custom code to exclude combinations that
+        # are not valid configurations
+        cs = self._add_forbidden_conditions(cs)

         self.configuration_space = cs
         self.dataset_properties = dataset_properties
