Skip to content

Commit 6853a13

Browse files
committed
revert import statement to fix the patch
1 parent c1ac4a8 commit 6853a13

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@
3737
calculate_loss,
3838
get_metrics,
3939
)
40-
from autoPyTorch.pipeline.image_classification import ImageClassificationPipeline
41-
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
42-
from autoPyTorch.pipeline.tabular_regression import TabularRegressionPipeline
43-
from autoPyTorch.pipeline.traditional_tabular_classification import TraditionalTabularClassificationPipeline
44-
from autoPyTorch.pipeline.traditional_tabular_regression import TraditionalTabularRegressionPipeline
40+
import autoPyTorch.pipeline.image_classification
41+
import autoPyTorch.pipeline.tabular_classification
42+
import autoPyTorch.pipeline.tabular_regression
43+
import autoPyTorch.pipeline.traditional_tabular_classification
44+
import autoPyTorch.pipeline.traditional_tabular_regression
4545
from autoPyTorch.utils.common import subsampler
4646
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
4747
from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
@@ -80,8 +80,9 @@ def __init__(self, config: str,
8080
self.dataset_properties = dataset_properties
8181
self.random_state = random_state
8282
self.init_params = init_params
83-
self.pipeline = TraditionalTabularClassificationPipeline(dataset_properties=dataset_properties,
84-
random_state=self.random_state)
83+
self.pipeline = autoPyTorch.pipeline.traditional_tabular_classification. \
84+
TraditionalTabularClassificationPipeline(dataset_properties=dataset_properties,
85+
random_state=self.random_state)
8586
configuration_space = self.pipeline.get_hyperparameter_search_space()
8687
default_configuration = configuration_space.get_default_configuration().get_dictionary()
8788
default_configuration['model_trainer:tabular_traditional_model:traditional_learner'] = config
@@ -119,7 +120,8 @@ def get_pipeline_representation(self) -> Dict[str, str]:
119120

120121
@staticmethod
121122
def get_default_pipeline_options() -> Dict[str, Any]:
122-
return TraditionalTabularClassificationPipeline.get_default_pipeline_options()
123+
return autoPyTorch.pipeline.traditional_tabular_classification. \
124+
TraditionalTabularClassificationPipeline.get_default_pipeline_options()
123125

124126

125127
class MyTraditionalTabularRegressionPipeline(BaseEstimator):
@@ -148,8 +150,9 @@ def __init__(self, config: str,
148150
self.dataset_properties = dataset_properties
149151
self.random_state = random_state
150152
self.init_params = init_params
151-
self.pipeline = TraditionalTabularRegressionPipeline(dataset_properties=dataset_properties,
152-
random_state=self.random_state)
153+
self.pipeline = autoPyTorch.pipeline.traditional_tabular_regression. \
154+
TraditionalTabularRegressionPipeline(dataset_properties=dataset_properties,
155+
random_state=self.random_state)
153156
configuration_space = self.pipeline.get_hyperparameter_search_space()
154157
default_configuration = configuration_space.get_default_configuration().get_dictionary()
155158
default_configuration['model_trainer:tabular_traditional_model:traditional_learner'] = config
@@ -182,7 +185,8 @@ def get_pipeline_representation(self) -> Dict[str, str]:
182185

183186
@staticmethod
184187
def get_default_pipeline_options() -> Dict[str, Any]:
185-
return TraditionalTabularRegressionPipeline.get_default_pipeline_options()
188+
return autoPyTorch.pipeline.traditional_tabular_regression.\
189+
TraditionalTabularRegressionPipeline.get_default_pipeline_options()
186190

187191

188192
class DummyClassificationPipeline(DummyClassifier):
@@ -456,7 +460,7 @@ def __init__(self, backend: Backend,
456460
elif isinstance(self.configuration, str):
457461
self.pipeline_class = MyTraditionalTabularRegressionPipeline
458462
elif isinstance(self.configuration, Configuration):
459-
self.pipeline_class = TabularRegressionPipeline
463+
self.pipeline_class = autoPyTorch.pipeline.tabular_regression.TabularRegressionPipeline
460464
else:
461465
raise ValueError('task {} not available'.format(self.task_type))
462466
self.predict_function = self._predict_regression
@@ -470,9 +474,9 @@ def __init__(self, backend: Backend,
470474
raise ValueError("Only tabular tasks are currently supported with traditional methods")
471475
elif isinstance(self.configuration, Configuration):
472476
if self.task_type in TABULAR_TASKS:
473-
self.pipeline_class = TabularClassificationPipeline
477+
self.pipeline_class = autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline
474478
elif self.task_type in IMAGE_TASKS:
475-
self.pipeline_class = ImageClassificationPipeline
479+
self.pipeline_class = autoPyTorch.pipeline.image_classification.ImageClassificationPipeline
476480
else:
477481
raise ValueError('task {} not available'.format(self.task_type))
478482
self.predict_function = self._predict_proba

0 commit comments

Comments
 (0)