-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGamaPipeline.py
47 lines (37 loc) · 1.54 KB
/
GamaPipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from enum import Enum
from typing import List, Tuple, Any, Union, Optional
import importlib
# Return-type alias for GamaPipeline. Both members are string forward
# references, so building this alias never imports sklearn or
# scikit_longitudinal; the noqa/type-ignore markers silence linters that
# cannot resolve the dotted names.
GamaPipelineTypeUnion = Union[
    "sklearn.pipeline.Pipeline",  # type: ignore # noqa: F821
    "scikit_longitudinal.pipeline.LongitudinalPipeline",  # type: ignore # noqa: F821
]
class GamaPipelineType(Enum):
    """Supported pipeline back-ends, each mapped to a ``(module, class)`` pair.

    Every member's value is a ``(module_name, class_name)`` tuple naming where
    the concrete pipeline class lives. The module is imported lazily by
    :meth:`import_pipeline_class`, so defining this enum imports neither
    back-end.
    """

    ScikitLearn = ("sklearn.pipeline", "Pipeline")
    ScikitLongitudinal = ("scikit_longitudinal.pipeline", "LongitudinalPipeline")

    def import_pipeline_class(self):
        """Import and return the pipeline class for this back-end.

        Returns:
            The attribute ``class_name`` of the module ``module_name``, both
            taken from this member's value.

        Raises:
            ModuleNotFoundError: If ``module_name`` (or a module it imports)
                cannot be found. The underlying error is chained as
                ``__cause__`` and its ``name`` attribute is preserved.
            AttributeError: If the module imports but has no ``class_name``.
        """
        module_name, class_name = self.value
        try:
            module = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            # Carry over ``name`` from the original error: it identifies the
            # module that is actually missing, which may be a transitive
            # dependency rather than ``module_name`` itself.
            raise ModuleNotFoundError(
                f"Could not import module {module_name}", name=e.name
            ) from e
        return getattr(module, class_name)

    @property
    def required_import(self) -> str:
        """Return an import statement that binds this back-end's class.

        The class is always made available under the name ``Pipeline``: it
        is aliased with ``as Pipeline`` when its real name differs.
        """
        module_name, class_name = self.value
        if class_name != "Pipeline":
            return f"from {module_name} import {class_name} as Pipeline"
        return f"from {module_name} import {class_name}"
class GamaPipeline:
    """Factory that builds a pipeline of the requested back-end type.

    Calling ``GamaPipeline(...)`` does not produce a ``GamaPipeline``
    instance: ``__new__`` returns an instance of the class resolved from
    ``pipeline_type``. Because that object is not an instance of ``cls``,
    Python skips ``GamaPipeline.__init__`` on it.
    """

    def __new__(  # type: ignore
        cls,
        steps: List[Tuple[str, Any]],
        pipeline_type: Optional[GamaPipelineType] = None,
        *args,
        **kwargs,
    ) -> GamaPipelineTypeUnion:
        """Instantiate the pipeline class selected by ``pipeline_type``.

        Args:
            steps: Non-empty list of ``(name, step)`` tuples, passed as the
                first positional argument to the pipeline constructor.
            pipeline_type: Back-end to use. It defaults to ``None`` only for
                signature compatibility; omitting it raises ``ValueError``.
            *args: Extra positional arguments forwarded to the constructor.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A new instance of the class returned by
            ``pipeline_type.import_pipeline_class()``.

        Raises:
            ValueError: If ``steps`` is None/empty or ``pipeline_type`` is None.
            TypeError: If ``pipeline_type`` is not a ``GamaPipelineType``.
            ModuleNotFoundError: If the back-end module cannot be imported.
        """
        # ``not steps`` is true for both ``None`` and an empty sequence.
        if not steps:
            raise ValueError("Pipeline steps cannot be None or empty")
        if pipeline_type is None:
            raise ValueError("Pipeline type cannot be None")
        # Fail with a clear message instead of an AttributeError when, e.g.,
        # a plain string such as "ScikitLearn" is passed.
        if not isinstance(pipeline_type, GamaPipelineType):
            raise TypeError(
                "pipeline_type must be a GamaPipelineType member, "
                f"got {type(pipeline_type).__name__}"
            )
        pipeline_class = pipeline_type.import_pipeline_class()
        return pipeline_class(steps, *args, **kwargs)