Skip to content

Commit 0cc6b46

Browse files
Apply suggestions from code review
Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent cda66c8 commit 0cc6b46

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

autoPyTorch/pipeline/components/setup/network_embedding/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,10 @@ def get_hyperparameter_search_space(
155155
default = default_
156156
break
157157

158-
categorical_columns = dataset_properties['categorical_columns'] \
159-
if isinstance(dataset_properties['categorical_columns'], List) else []
158+
if isinstance(dataset_properties['categorical_columns'], list):
159+
categorical_columns = dataset_properties['categorical_columns']
160+
else:
161+
categorical_columns = []
160162

161163
updates = self._get_search_space_updates()
162164
if '__choice__' in updates.keys():

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,8 @@ def _swa_update(self) -> None:
338338
"""
339339
perform swa model update
340340
"""
341-
assert self.swa_model is not None, "SWA model can't be none when" \
342-
" stochastic weight averaging is enabled"
341+
if self.swa_model is None:
342+
raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
343343
self.swa_model.update_parameters(self.model)
344344
self.swa_updated = True
345345

@@ -350,8 +350,8 @@ def _se_update(self, epoch: int) -> None:
350350
epoch (int):
351351
current epoch
352352
"""
353-
assert self.model_snapshots is not None, "model snapshots container can't be " \
354-
"none when snapshot ensembling is enabled"
353+
if self.model_snapshots is None:
354+
raise ValueError("model snapshots cannot be None when snapshot ensembling is enabled")
355355
is_last_epoch = (epoch == self.budget_tracker.max_epochs)
356356
if is_last_epoch and self.use_stochastic_weight_averaging:
357357
model_copy = deepcopy(self.swa_model)

0 commit comments

Comments
 (0)