Skip to content

Commit 8955996

Browse files
committed
fix flake and mypy
1 parent d3df40f commit 8955996

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/variance_thresholding/VarianceThreshold.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from typing import Any, Dict, Optional, Union
2+
23
import numpy as np
4+
35
from sklearn.feature_selection import VarianceThreshold as SklearnVarianceThreshold
4-
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
56

7+
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
68
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import \
79
autoPyTorchTabularPreprocessingComponent
810

@@ -31,11 +33,12 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
3133
return X
3234

3335
@staticmethod
34-
def get_properties(dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None
36+
def get_properties(
37+
dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None
3538
) -> Dict[str, Union[str, bool]]:
3639

3740
return {
3841
'shortname': 'Variance Threshold',
3942
'name': 'Variance Threshold (constant feature removal)',
4043
'handles_sparse': True,
41-
}
44+
}

test/test_pipeline/components/preprocessing/test_variance_thresholding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from sklearn.base import BaseEstimator
66
from sklearn.compose import make_column_transformer
77

8-
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.variance_thresholding.VarianceThreshold import VarianceThreshold
8+
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.variance_thresholding. \
9+
VarianceThreshold import VarianceThreshold
910

1011

1112
def test_variance_threshold():
@@ -38,8 +39,8 @@ def test_variance_threshold():
3839

3940
# make column transformer with returned encoder to fit on data
4041
column_transformer = make_column_transformer((variance_threshold,
41-
X['dataset_properties']['numerical_columns']),
42-
remainder='passthrough')
42+
X['dataset_properties']['numerical_columns']),
43+
remainder='passthrough')
4344
column_transformer = column_transformer.fit(X['X_train'])
4445
transformed = column_transformer.transform(data[test_indices])
4546

0 commit comments

Comments
 (0)