Skip to content

Commit 6bc3f06

Browse files
overfitting
1 parent cac4e0c commit 6bc3f06

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# notes for this course can be founda at:
2+
# https://www.udemy.com/data-science-linear-regression-in-python
3+
4+
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
8+
# make up some data and plot it
9+
N = 100
10+
X = np.linspace(0, 6*np.pi, N)
11+
Y = np.sin(X)
12+
13+
plt.plot(X, Y)
14+
plt.show()
15+
16+
17+
def make_poly(X, deg):
18+
n = len(X)
19+
data = [np.ones(n)]
20+
for d in xrange(deg):
21+
data.append(X**(d+1))
22+
return np.vstack(data).T
23+
24+
25+
def fit(X, Y):
26+
return np.linalg.solve(X.T.dot(X), X.T.dot(Y))
27+
28+
29+
def fit_and_display(X, Y, sample, deg):
30+
N = len(X)
31+
train_idx = np.random.choice(N, sample)
32+
Xtrain = X[train_idx]
33+
Ytrain = Y[train_idx]
34+
35+
plt.scatter(Xtrain, Ytrain)
36+
plt.show()
37+
38+
# fit polynomial
39+
Xtrain_poly = make_poly(Xtrain, deg)
40+
w = fit(Xtrain_poly, Ytrain)
41+
42+
# display the polynomial
43+
X_poly = make_poly(X, deg)
44+
Y_hat = X_poly.dot(w)
45+
plt.plot(X, Y)
46+
plt.plot(X, Y_hat)
47+
plt.scatter(Xtrain, Ytrain)
48+
plt.title("deg = %d" % deg)
49+
plt.show()
50+
51+
for deg in (5, 6, 7, 8, 9):
52+
fit_and_display(X, Y, 10, deg)
53+
54+
55+
def get_mse(Y, Yhat):
56+
d = Y - Yhat
57+
return d.dot(d) / len(d)
58+
59+
60+
def plot_train_vs_test_curves(X, Y, sample=20, max_deg=20):
61+
N = len(X)
62+
train_idx = np.random.choice(N, sample)
63+
Xtrain = X[train_idx]
64+
Ytrain = Y[train_idx]
65+
66+
test_idx = [idx for idx in xrange(N) if idx not in train_idx]
67+
# test_idx = np.random.choice(N, sample)
68+
Xtest = X[test_idx]
69+
Ytest = Y[test_idx]
70+
71+
mse_trains = []
72+
mse_tests = []
73+
for deg in xrange(max_deg+1):
74+
Xtrain_poly = make_poly(Xtrain, deg)
75+
w = fit(Xtrain_poly, Ytrain)
76+
Yhat_train = Xtrain_poly.dot(w)
77+
mse_train = get_mse(Ytrain, Yhat_train)
78+
79+
Xtest_poly = make_poly(Xtest, deg)
80+
Yhat_test = Xtest_poly.dot(w)
81+
mse_test = get_mse(Ytest, Yhat_test)
82+
83+
mse_trains.append(mse_train)
84+
mse_tests.append(mse_test)
85+
86+
plt.plot(mse_trains, label="train mse")
87+
plt.plot(mse_tests, label="test mse")
88+
plt.legend()
89+
plt.show()
90+
91+
plt.plot(mse_trains, label="train mse")
92+
plt.legend()
93+
plt.show()
94+
95+
plot_train_vs_test_curves(X, Y)

0 commit comments

Comments
 (0)