Skip to content

Commit 15e0b67

Browse files
committed
Added all models converters
1 parent 4581719 commit 15e0b67

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

__init__.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
2+
import sklearn
3+
from sklearn.tree import DecisionTreeClassifier
4+
from sklearn.neural_network import MLPClassifier
5+
from sklearn.linear_model import RidgeClassifier
6+
from sklearn.svm import SVC
7+
8+
import converters.decison_tree_converter as tree_cvt
9+
import converters.ridge_converter as ridge_cvt
10+
import converters.svm_converter as svm_cvt
11+
import converters.mlp_converter as mlp_cvt
12+
13+
def convert_model(clf):
14+
print("CONVERT_MODEL ", clf.__class__)
15+
if(clf.__class__ == sklearn.tree._classes.DecisionTreeClassifier):
16+
lConverter = tree_cvt.decision_tree_converter()
17+
return lConverter.convert_classifier(clf)
18+
if(clf.__class__ == sklearn.tree._classes.DecisionTreeRegressor):
19+
lConverter = tree_cvt.decision_tree_converter()
20+
return lConverter.convert_regressor(clf)
21+
if(clf.__class__ == sklearn.svm._classes.SVC):
22+
lConverter = svm_cvt.svm_converter()
23+
return lConverter.convert_classifier(clf)
24+
if(clf.__class__ == sklearn.svm._classes.SVR):
25+
lConverter = svm_cvt.svm_converter()
26+
return lConverter.convert_regressor(clf)
27+
if(clf.__class__ == sklearn.linear_model._ridge.RidgeClassifier):
28+
lConverter = ridge_cvt.ridge_converter()
29+
return lConverter.convert_classifier(clf)
30+
if(clf.__class__ == sklearn.linear_model._ridge.Ridge):
31+
lConverter = ridge_cvt.ridge_converter()
32+
return lConverter.convert_regressor(clf)
33+
if(clf.__class__ == sklearn.neural_network._multilayer_perceptron.MLPClassifier):
34+
lConverter = mlp_cvt.mlp_converter()
35+
return lConverter.convert_classifier(clf)
36+
if(clf.__class__ == sklearn.neural_network._multilayer_perceptron.MLPRegressor):
37+
lConverter = mlp_cvt.mlp_converter()
38+
return lConverter.convert_regressor(clf)
39+
40+
print("WARNING_CANNOT_CONVERT_MODEL ", clf.__class__)
41+
return None

converters/generic_converter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
3+
class json_converter:
4+
5+
def __init__(self):
6+
pass
7+

0 commit comments

Comments
 (0)