Commit c7cc712

fix flake
1 parent 4d90706 commit c7cc712

File tree

3 files changed: +18 −10 lines changed

autoPyTorch/api/base_task.py

Lines changed: 11 additions & 4 deletions
@@ -1491,6 +1491,10 @@ def fit_pipeline(
             dataset_name=dataset_name
         )

+        # dataset_name is created inside the constructor of BaseDataset
+        # we expect it to be not None. This is for mypy
+        assert dataset.dataset_name is not None
+
         # TAE expects each configuration to have a config_id.
         # For fitting a pipeline as it is not part of the
         # search process, it makes sense to set it to 0
@@ -1506,9 +1510,6 @@ def fit_pipeline(
         self._backend.save_datamanager(dataset)

         if self._logger is None:
-            # dataset_name is created inside the constructor of BaseDataset
-            # we expect it to be not None. This is for mypy
-            assert dataset.dataset_name is not None
             self._logger = self._get_logger(dataset.dataset_name)

         include_components = self.include_components if include_components is None else include_components
@@ -1576,6 +1577,7 @@ def fit_pipeline(
         )

         fitted_pipeline = self._get_fitted_pipeline(
+            dataset_name=dataset.dataset_name,
             pipeline_idx=run_info.config.config_id + tae.initial_num_run,
             run_info=run_info,
             run_value=run_value,
@@ -1588,11 +1590,16 @@ def fit_pipeline(

     def _get_fitted_pipeline(
         self,
+        dataset_name: str,
         pipeline_idx: int,
         run_info: RunInfo,
         run_value: RunValue,
         disable_file_output: List[Union[str, DisableFileOutputParameters]]
     ) -> Optional[BasePipeline]:
+
+        if self._logger is None:
+            self._logger = self._get_logger(str(dataset_name))
+
         if run_value.status != StatusType.SUCCESS:
             warnings.warn(f"Fitting pipeline failed with status: {run_value.status}"
                           f", additional_info: {run_value.additional_info}")
@@ -1606,7 +1613,7 @@ def _get_fitted_pipeline(
         else:
             load_function = self._backend.load_model_by_seed_and_id_and_budget

-        return load_function(
+        return load_function(  # type: ignore[no-any-return]
             seed=self.seed,
             idx=pipeline_idx,
             budget=float(run_info.budget),
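
The `assert dataset.dataset_name is not None` hoisted above is the usual way to narrow an `Optional` attribute for mypy: after the assert, mypy treats the attribute as `str` rather than `Optional[str]` on every following line. A minimal, self-contained sketch of the pattern (the `Dataset` class below is a stand-in, not autoPyTorch's `BaseDataset`):

from typing import Optional


class Dataset:
    def __init__(self, dataset_name: Optional[str] = None) -> None:
        # stand-in for BaseDataset: a name is always filled in by the constructor
        self.dataset_name: Optional[str] = dataset_name or "generated-name"


def logger_name(dataset: Dataset) -> str:
    # without the assert, mypy flags this as returning Optional[str] where str is expected
    assert dataset.dataset_name is not None  # narrows Optional[str] to str
    return dataset.dataset_name

Moving the assert before the first use means the narrowing holds on every code path, not only inside the `if self._logger is None:` branch where it previously lived.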

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -858,7 +858,7 @@ def file_output(
             self.backend.save_targets_ensemble(self.Y_optimization)

         if getattr(self, 'pipelines', None) is not None:
-            if self.pipelines[0] is not None and len(self.pipelines) > 0:
+            if self.pipelines[0] is not None and len(self.pipelines) > 0:  # type: ignore[index, arg-type]
                 if 'pipelines' not in self.disable_file_output:
                     if self.task_type in CLASSIFICATION_TASKS:
                         pipelines = VotingClassifier(estimators=None, voting='soft', )
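
The bracketed codes limit the suppression to the specific diagnostics mypy raises here (`index` for subscripting an `Optional` sequence, `arg-type` for passing one to `len`) rather than silencing every error on the line. A minimal sketch of the same situation, with a hypothetical `Evaluator` stand-in:

from typing import List, Optional


class Evaluator:
    def __init__(self) -> None:
        # like the evaluator's attribute: None until pipelines are fitted
        self.pipelines: Optional[List[Optional[str]]] = None

    def file_output(self) -> None:
        if getattr(self, 'pipelines', None) is not None:
            # mypy cannot connect the getattr() guard to self.pipelines, so it
            # still types the attribute as Optional[...] on the next line
            if self.pipelines[0] is not None and len(self.pipelines) > 0:  # type: ignore[index, arg-type]
                print("would write pipelines to disk")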

test/test_utils/test_common.py

Lines changed: 6 additions & 5 deletions
@@ -1,6 +1,8 @@
 """
 This tests the functionality in autoPyTorch/utils/common.
 """
+from enum import Enum
+
 import pytest

 from autoPyTorch.utils.common import autoPyTorchEnum
@@ -11,6 +13,10 @@ class SubEnum(autoPyTorchEnum):
     y = "y"


+class DummyEnum(Enum):  # You need to move it on top
+    x = "x"
+
+
 @pytest.mark.parametrize('iter',
                          ([SubEnum.x],
                           ["x"],
@@ -32,9 +38,6 @@ def test_autopytorch_enum(iter):

     assert e in iter

-class DummyEnum(Enum):  # You need to move it on top
-    x = "x"
-

 @pytest.mark.parametrize('iter',
                          [[SubEnum.y],
@@ -67,5 +70,3 @@ def test_raise_errors_autopytorch_enum(others):

     with pytest.raises(RuntimeError):
         SubEnum.x == others
-
-
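
The "You need to move it on top" comment reflects how `pytest.mark.parametrize` behaves: its argument list is built when the module is imported, so any class a later parametrize list references must already be defined at that point, presumably why `DummyEnum` is hoisted above the tests here. A minimal sketch of the mechanism (test and names illustrative, not the actual test file):

from enum import Enum

import pytest


class DummyEnum(Enum):  # must precede any parametrize list that uses it
    x = "x"


# This list is evaluated at import/collection time; if DummyEnum were
# defined further down the module, collection would fail with NameError.
@pytest.mark.parametrize('member', [DummyEnum.x, "x"])
def test_member_value(member):
    value = member.value if isinstance(member, DummyEnum) else member
    assert value == "x"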
