-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
104 lines (70 loc) · 2.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA
# Page heading followed by a short introduction.
st.title('Find a Suitable Machine learning model for a dataset')
intro_text = """
Explore different datasets and find out which classifier model to use for better accuracy. """
st.write(intro_text)

# Sidebar drop-downs; the selected labels are passed to get_dataset,
# add_parameter_ui and get_classifier further down.
dataset_options = ("Iris", "Breast cancer", "Wine dataset")
classifier_options = ('KNN', 'SVM', 'Random forest')
dataset_name = st.sidebar.selectbox("Select Dataset", dataset_options)
classifier_name = st.sidebar.selectbox('Select Model', classifier_options)
def get_dataset(dataset_name):
    """Load the scikit-learn dataset matching the given label.

    Args:
        dataset_name: 'Iris' or 'Breast cancer'; any other value selects
            the wine dataset.

    Returns:
        A ``(X, Y)`` tuple made of the loaded object's ``data`` and
        ``target`` attributes.
    """
    if dataset_name == 'Iris':
        loader = datasets.load_iris
    elif dataset_name == 'Breast cancer':
        loader = datasets.load_breast_cancer
    else:
        # Every other label falls back to the wine dataset.
        loader = datasets.load_wine
    bunch = loader()
    return bunch.data, bunch.target
# Load the selected dataset and print a short summary of it.
X, Y = get_dataset(dataset_name)
st.write('Shape of the dataset ', X.shape)
class_labels = np.unique(Y)
st.write('Number of classes ', len(class_labels))
def add_parameter_ui(classifier_name):
    """Create the sidebar sliders for the chosen model and collect their values.

    Args:
        classifier_name: 'KNN' or 'SVM'; any other value gets the
            ``max_depth`` / ``n_estimators`` sliders.

    Returns:
        A dict keyed by "K" for KNN, by "C" for SVM, and by "max_depth"
        and "n_estimators" otherwise.
    """
    sidebar = st.sidebar
    if classifier_name == 'KNN':
        return {"K": sidebar.slider('K', 1, 15)}
    if classifier_name == 'SVM':
        return {"C": sidebar.slider("C", 0.01, 1.0)}
    # Sliders are created in the same order as before: max_depth, then
    # n_estimators.
    depth = sidebar.slider("max_depth", 2, 15)
    tree_count = sidebar.slider("n_estimators", 2, 15)
    return {"max_depth": depth, "n_estimators": tree_count}
# Build the sidebar sliders for the selected model and keep their current values.
params=add_parameter_ui(classifier_name)
def get_classifier(params, classifier_name):
    """Instantiate the estimator selected by name, configured from ``params``.

    Args:
        params: dict holding "K" for KNN, "C" for SVM, or "n_estimators"
            and "max_depth" for the remaining case.
        classifier_name: 'KNN' or 'SVM'; any other value yields a random
            forest.

    Returns:
        An unfitted KNeighborsClassifier, SVC or RandomForestClassifier.
    """
    if classifier_name == 'KNN':
        return KNeighborsClassifier(n_neighbors=params['K'])
    if classifier_name == 'SVM':
        return SVC(C=params['C'])
    # Fixed random_state so the forest is built with the same seed each run.
    return RandomForestClassifier(
        n_estimators=params['n_estimators'],
        max_depth=params['max_depth'],
        random_state=1234,
    )
# Build the estimator configured by the sidebar widgets.
clf = get_classifier(params, classifier_name)

# Classification: hold out 20% of the samples for evaluation; the fixed
# random_state seeds the split.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=1234
)
clf.fit(X_train, Y_train)
y_pred = clf.predict(X_test)
# accuracy_score is documented as (y_true, y_pred); pass the ground truth first.
acc = accuracy_score(Y_test, y_pred)
st.write(f'Classifier={classifier_name}')
st.write(f'accuracy={acc}')

# Project the full feature matrix onto two principal components so it can be
# drawn as a 2-D scatter plot coloured by class label.
pca = PCA(2)
X_projected = pca.fit_transform(X)
x1 = X_projected[:, 0]
x2 = X_projected[:, 1]

# Draw on an explicit Figure/Axes pair instead of pyplot's implicit
# "current figure" state.
fig, ax = plt.subplots()
points = ax.scatter(x1, x2, c=Y, alpha=0.8, cmap='viridis')
ax.set_xlabel('Principal Component 1')
ax.set_ylabel('Principal Component 2')
fig.colorbar(points, ax=ax)
st.pyplot(fig)
# pyplot keeps a reference to every figure it creates; close this one once it
# has been rendered so figures do not pile up across script runs.
plt.close(fig)