Skip to content

Commit 63fcdc6

Browse files
supervised class regression
1 parent 74c67ce commit 63fcdc6

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

supervised_class/regression.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
from sklearn.neighbors import KNeighborsRegressor
4+
from sklearn.tree import DecisionTreeRegressor
5+
6+
N = 200
7+
X = np.linspace(0, 10, N).reshape(N, 1)
8+
Y = np.sin(X)
9+
10+
Ntrain = 20
11+
idx = np.random.choice(N, Ntrain)
12+
Xtrain = X[idx]
13+
Ytrain = Y[idx]
14+
15+
knn = KNeighborsRegressor(n_neighbors=2, weights='distance')
16+
knn.fit(Xtrain, Ytrain)
17+
Yknn = knn.predict(X)
18+
19+
dt = DecisionTreeRegressor()
20+
dt.fit(Xtrain, Ytrain)
21+
Ydt = dt.predict(X)
22+
23+
plt.scatter(Xtrain, Ytrain) # show the training points
24+
plt.plot(X, Y) # show the original data
25+
plt.plot(X, Yknn, label='KNN')
26+
plt.plot(X, Ydt, label='Decision Tree')
27+
plt.legend()
28+
plt.show()
29+
30+

0 commit comments

Comments
 (0)