Skip to content

Commit 8bf6280

Browse files
committed
Enable learned embeddings, fix bug with non-cyclic schedulers
1 parent 484ead4 commit 8bf6280

File tree

2 files changed

+86
-86
lines changed

2 files changed

+86
-86
lines changed

autoPyTorch/pipeline/components/setup/network_embedding/__init__.py

Lines changed: 46 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -148,71 +148,62 @@ def get_hyperparameter_search_space(
148148
if default is None:
149149
defaults = [
150150
'NoEmbedding',
151-
# 'LearnedEntityEmbedding',
151+
'LearnedEntityEmbedding',
152152
]
153153
for default_ in defaults:
154154
if default_ in available_embedding:
155155
default = default_
156156
break
157157

158-
# Restrict embedding to NoEmbedding until preprocessing is fixed
159-
embedding = CSH.CategoricalHyperparameter('__choice__',
160-
['NoEmbedding'],
161-
default_value=default)
158+
categorical_columns = dataset_properties['categorical_columns'] \
159+
if isinstance(dataset_properties['categorical_columns'], List) else []
160+
161+
updates = self._get_search_space_updates()
162+
if '__choice__' in updates.keys():
163+
choice_hyperparameter = updates['__choice__']
164+
if not set(choice_hyperparameter.value_range).issubset(available_embedding):
165+
raise ValueError("Expected given update for {} to have "
166+
"choices in {} got {}".format(self.__class__.__name__,
167+
available_embedding,
168+
choice_hyperparameter.value_range))
169+
if len(categorical_columns) == 0:
170+
assert len(choice_hyperparameter.value_range) == 1
171+
if 'NoEmbedding' not in choice_hyperparameter.value_range:
172+
raise ValueError("Provided {} in choices, however, the dataset "
173+
"is incompatible with it".format(choice_hyperparameter.value_range))
174+
embedding = CSH.CategoricalHyperparameter('__choice__',
175+
choice_hyperparameter.value_range,
176+
default_value=choice_hyperparameter.default_value)
177+
else:
178+
179+
if len(categorical_columns) == 0:
180+
default = 'NoEmbedding'
181+
if include is not None and default not in include:
182+
raise ValueError("Provided {} in include, however, the dataset "
183+
"is incompatible with it".format(include))
184+
embedding = CSH.CategoricalHyperparameter('__choice__',
185+
['NoEmbedding'],
186+
default_value=default)
187+
else:
188+
embedding = CSH.CategoricalHyperparameter('__choice__',
189+
list(available_embedding.keys()),
190+
default_value=default)
191+
162192
cs.add_hyperparameter(embedding)
193+
for name in embedding.choices:
194+
updates = self._get_search_space_updates(prefix=name)
195+
config_space = available_embedding[name].get_hyperparameter_search_space(dataset_properties, # type: ignore
196+
**updates)
197+
parent_hyperparameter = {'parent': embedding, 'value': name}
198+
cs.add_configuration_space(
199+
name,
200+
config_space,
201+
parent_hyperparameter=parent_hyperparameter
202+
)
203+
163204
self.configuration_space_ = cs
164205
self.dataset_properties_ = dataset_properties
165206
return cs
166-
# categorical_columns = dataset_properties['categorical_columns'] \
167-
# if isinstance(dataset_properties['categorical_columns'], List) else []
168-
169-
# updates = self._get_search_space_updates()
170-
# if '__choice__' in updates.keys():
171-
# choice_hyperparameter = updates['__choice__']
172-
# if not set(choice_hyperparameter.value_range).issubset(available_embedding):
173-
# raise ValueError("Expected given update for {} to have "
174-
# "choices in {} got {}".format(self.__class__.__name__,
175-
# available_embedding,
176-
# choice_hyperparameter.value_range))
177-
# if len(categorical_columns) == 0:
178-
# assert len(choice_hyperparameter.value_range) == 1
179-
# if 'NoEmbedding' not in choice_hyperparameter.value_range:
180-
# raise ValueError("Provided {} in choices, however, the dataset "
181-
# "is incompatible with it".format(choice_hyperparameter.value_range))
182-
# embedding = CSH.CategoricalHyperparameter('__choice__',
183-
# choice_hyperparameter.value_range,
184-
# default_value=choice_hyperparameter.default_value)
185-
# else:
186-
187-
# if len(categorical_columns) == 0:
188-
# default = 'NoEmbedding'
189-
# if include is not None and default not in include:
190-
# raise ValueError("Provided {} in include, however, the dataset "
191-
# "is incompatible with it".format(include))
192-
# embedding = CSH.CategoricalHyperparameter('__choice__',
193-
# ['NoEmbedding'],
194-
# default_value=default)
195-
# else:
196-
# embedding = CSH.CategoricalHyperparameter('__choice__',
197-
# list(available_embedding.keys()),
198-
# default_value=default)
199-
200-
# cs.add_hyperparameter(embedding)
201-
# for name in embedding.choices:
202-
# updates = self._get_search_space_updates(prefix=name)
203-
# config_space = available_embedding[name].get_hyperparameter_search_space(
204-
# dataset_properties, # type: ignore
205-
# **updates)
206-
# parent_hyperparameter = {'parent': embedding, 'value': name}
207-
# cs.add_configuration_space(
208-
# name,
209-
# config_space,
210-
# parent_hyperparameter=parent_hyperparameter
211-
# )
212-
213-
# self.configuration_space_ = cs
214-
# self.dataset_properties_ = dataset_properties
215-
# return cs
216207

217208
def transform(self, X: np.ndarray) -> np.ndarray:
218209
assert self.choice is not None, "Cannot call transform before the object is initialized"

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,35 @@ def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
334334
"""
335335
pass
336336

337+
def _swa_update(self) -> None:
338+
"""
339+
perform swa model update
340+
"""
341+
assert self.swa_model is not None, "SWA model can't be none when" \
342+
" stochastic weight averaging is enabled"
343+
self.swa_model.update_parameters(self.model)
344+
self.swa_updated = True
345+
346+
def _se_update(self, epoch: int) -> None:
347+
"""
348+
Add latest model or swa_model to model snapshot ensemble
349+
Args:
350+
epoch (int):
351+
current epoch
352+
"""
353+
assert self.model_snapshots is not None, "model snapshots container can't be " \
354+
"none when snapshot ensembling is enabled"
355+
is_last_epoch = (epoch == self.budget_tracker.max_epochs)
356+
if is_last_epoch and self.use_stochastic_weight_averaging:
357+
model_copy = deepcopy(self.swa_model)
358+
else:
359+
model_copy = deepcopy(self.model)
360+
361+
assert model_copy is not None
362+
model_copy.cpu()
363+
self.model_snapshots.append(model_copy)
364+
self.model_snapshots = self.model_snapshots[-self.se_lastk:]
365+
337366
def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
338367
"""
339368
Optional place holder for AutoPytorch Extensions.
@@ -344,39 +373,19 @@ def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
344373
if X['is_cyclic_scheduler']:
345374
if hasattr(self.scheduler, 'T_cur') and self.scheduler.T_cur == 0 and epoch != 1:
346375
if self.use_stochastic_weight_averaging:
347-
assert self.swa_model is not None, "SWA model can't be none when" \
348-
" stochastic weight averaging is enabled"
349-
self.swa_model.update_parameters(self.model)
350-
self.swa_updated = True
376+
self._swa_update()
351377
if self.use_snapshot_ensemble:
352-
assert self.model_snapshots is not None, "model snapshots container can't be " \
353-
"none when snapshot ensembling is enabled"
354-
is_last_epoch = (epoch == self.budget_tracker.max_epochs)
355-
if is_last_epoch and self.use_stochastic_weight_averaging:
356-
model_copy = deepcopy(self.swa_model)
357-
else:
358-
model_copy = deepcopy(self.model)
359-
360-
assert model_copy is not None
361-
model_copy.cpu()
362-
self.model_snapshots.append(model_copy)
363-
self.model_snapshots = self.model_snapshots[-self.se_lastk:]
378+
self._se_update(epoch=epoch)
364379
else:
365-
if epoch > self._budget_threshold:
366-
if self.use_stochastic_weight_averaging:
367-
assert self.swa_model is not None, "SWA model can't be none when" \
368-
" stochastic weight averaging is enabled"
369-
self.swa_model.update_parameters(self.model)
370-
self.swa_updated = True
371-
if self.use_snapshot_ensemble:
372-
assert self.model_snapshots is not None, "model snapshots container can't be " \
373-
"none when snapshot ensembling is enabled"
374-
model_copy = deepcopy(self.swa_model) if self.use_stochastic_weight_averaging \
375-
else deepcopy(self.model)
376-
assert model_copy is not None
377-
model_copy.cpu()
378-
self.model_snapshots.append(model_copy)
379-
self.model_snapshots = self.model_snapshots[-self.se_lastk:]
380+
if epoch > self._budget_threshold and self.use_stochastic_weight_averaging:
381+
self._swa_update()
382+
383+
if (
384+
self.use_snapshot_ensemble
385+
and self.budget_tracker.max_epochs is not None
386+
and epoch > (self.budget_tracker.max_epochs - self.se_lastk)
387+
):
388+
self._se_update(epoch=epoch)
380389
return False
381390

382391
def _scheduler_step(

0 commit comments

Comments (0)