File tree 2 files changed +8
-6
lines changed
autoPyTorch/pipeline/components 2 files changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -155,8 +155,10 @@ def get_hyperparameter_search_space(
155
155
default = default_
156
156
break
157
157
158
- categorical_columns = dataset_properties ['categorical_columns' ] \
159
- if isinstance (dataset_properties ['categorical_columns' ], List ) else []
158
+ if isinstance (dataset_properties ['categorical_columns' ], list ):
159
+ categorical_columns = dataset_properties ['categorical_columns' ]
160
+ else :
161
+ categorical_columns = []
160
162
161
163
updates = self ._get_search_space_updates ()
162
164
if '__choice__' in updates .keys ():
Original file line number Diff line number Diff line change @@ -338,8 +338,8 @@ def _swa_update(self) -> None:
338
338
"""
339
339
perform swa model update
340
340
"""
341
- assert self .swa_model is not None , "SWA model can't be none when" \
342
- " stochastic weight averaging is enabled"
341
+ if self .swa_model is None :
342
+ raise ValueError ( "SWA model cannot be none when stochastic weight averaging is enabled")
343
343
self .swa_model .update_parameters (self .model )
344
344
self .swa_updated = True
345
345
@@ -350,8 +350,8 @@ def _se_update(self, epoch: int) -> None:
350
350
epoch (int):
351
351
current epoch
352
352
"""
353
- assert self .model_snapshots is not None , "model snapshots container can't be " \
354
- "none when snapshot ensembling is enabled"
353
+ if self .model_snapshots is None :
354
+ raise ValueError ( "model snapshots cannot be None when snapshot ensembling is enabled")
355
355
is_last_epoch = (epoch == self .budget_tracker .max_epochs )
356
356
if is_last_epoch and self .use_stochastic_weight_averaging :
357
357
model_copy = deepcopy (self .swa_model )
You can’t perform that action at this time.
0 commit comments