Skip to content

Commit 57bb577

Browse files
author
Kenzo
committed
add 2d and poly examples
1 parent cd98c55 commit 57bb577

File tree

6 files changed

+304
-0
lines changed

6 files changed

+304
-0
lines changed

linear_regression_class/data_2d.csv

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
17.9302012052,94.5205919533,320.259529602
2+
97.1446971852,69.5932819844,404.634471526
3+
81.7759007845,5.73764809688,181.485107741
4+
55.8543424175,70.3259016768,321.773638018
5+
49.3665499855,75.1140401571,322.465485583
6+
3.19270246506,29.2562988631,94.6188108954
7+
49.2007840582,86.1444385075,356.348092747
8+
21.882803904,46.8415051959,181.653769226
9+
79.5098627219,87.397355535,423.557743194
10+
88.1538874975,65.2056419279,369.229245443
11+
60.7438543399,99.9576339008,427.605803661
12+
67.4155819451,50.3683096094,292.471821553
13+
48.3181157719,99.1289531425,395.529811407
14+
28.8299719729,87.1849488537,319.031348455
15+
43.853742664,64.4736390798,287.4281441
16+
25.3136940868,83.5452942552,292.768908839
17+
10.807726675,45.6955685904,159.663307674
18+
98.365745882,82.6973935253,438.798963866
19+
29.1469099692,66.3651067611,250.986309034
20+
65.1003018959,33.3538834975,231.711507921
21+
24.6441134909,39.5400527406,163.398160832
22+
37.5598048804,1.34572784158,83.4801551365
23+
88.1645062351,95.1536625708,466.265805791
24+
13.8346208354,25.494048204,100.886430268
25+
64.4108437467,77.2598381277,365.641048062
26+
68.925991798,97.4536007955,426.140015493
27+
39.4884422375,50.8561281903,235.532389457
28+
52.4631776831,59.7765096878,283.291640318
29+
48.4847869797,66.9703542227,298.581440375
30+
8.06208781377,98.2426001448,309.234108782
31+
32.7318877149,18.8535355289,129.610139015
32+
11.6523787962,66.2645117418,224.150542091
33+
13.7303535349,70.4725091314,235.305665631
34+
8.18555176536,41.8519894198,153.484189455
35+
53.6098761481,94.560121637,394.939444253
36+
95.3686098922,47.2955069599,336.126738879
37+
87.3336092062,93.8039343271,449.363351987
38+
66.3576110919,81.8475512843,387.014816485
39+
19.7547175302,65.5233009197,240.389441595
40+
21.1334404504,47.4371819938,177.14828058
41+
22.37386481,25.9556275364,119.61125795
42+
93.9904040537,0.127890520058,196.716166657
43+
86.7201980995,18.413766785,236.260808358
44+
98.9983729904,60.2312656882,384.381344695
45+
3.59396564372,96.2522173231,293.237183045
46+
15.1023633732,92.5569035696,304.890882807
47+
97.8341407695,2.02390810433,201.293598424
48+
19.9382196923,46.7782734632,170.610093035
49+
30.3735111369,58.7775251616,242.373484227
50+
73.2928831485,67.6696277593,353.082991304
51+
52.2309008804,81.9024482531,348.725688557
52+
86.4295761137,66.5402275975,365.959970954
53+
93.4008021441,18.0752459352,235.472381633
54+
13.213460059,91.4888587787,300.606878255
55+
4.59346270394,46.3359315158,145.818745251
56+
15.6692915825,35.5437439971,138.880334695
57+
52.9593597724,68.7202096086,317.163707741
58+
56.8175212312,47.5727319202,254.903631301
59+
51.1335430813,78.0421674552,334.5843335
60+
7.86216471526,17.7290817759,69.3555888141
61+
54.6986037014,92.7445841407,386.859937239
62+
86.3990630133,41.8886945857,294.871713576
63+
11.9475060186,42.9613867358,156.754219545
64+
70.3584010648,83.7062345104,391.806135291
65+
29.0223663255,84.327783082,319.310462851
66+
42.7594799067,97.4933260814,376.291589478
67+
96.2156564389,25.8342825832,280.617043873
68+
53.2277276552,27.9055085669,194.430465068
69+
30.3609896715,0.939644215313,69.6488631767
70+
83.2775653897,73.1793485695,384.597184993
71+
30.1876924786,7.1465385989,89.5390083986
72+
11.7884184617,51.6977608445,181.550682802
73+
18.2924240117,61.9779760484,224.773382884
74+
96.7126676927,9.0291015125,219.567093701
75+
31.0127386947,78.2833824603,298.490216481
76+
11.3972607536,61.728693243,199.944044715
77+
17.3925557906,4.24114086314,43.9156923531
78+
72.1826937267,34.5390721676,256.068378234
79+
73.980020785,3.71649343826,159.372580639
80+
94.4930583509,88.41719702,447.132703621
81+
84.5628207335,20.2411621935,233.078829579
82+
51.7424739723,11.009747964,131.070179635
83+
53.7485903953,60.0251022958,298.814332704
84+
85.0508347563,95.7369969538,451.80352299
85+
46.7772504514,90.2022062395,368.366435952
86+
49.7584341737,52.834494361,254.706774007
87+
24.1192565017,42.1028107802,168.308432759
88+
27.2015764537,29.9787492892,146.34226
89+
7.00959616769,55.876058392,176.810148976
90+
97.6469496688,8.14762512696,219.16028041
91+
1.38298250945,84.9440869198,252.905652954
92+
22.3235303534,27.5150750402,127.570478844
93+
45.0454062325,93.5204022248,375.822340339
94+
40.1639914677,0.161699234822,80.3890193252
95+
53.182739793,8.17031616228,142.718183077
96+
46.4567791559,82.000170914,336.876154384
97+
77.1303006946,95.1887594549,438.460586127
98+
68.600607572,72.5711807231,355.900286918
99+
41.6938871165,69.241125973,284.834636711
100+
4.1426693978,52.2547263792,168.034400947

linear_regression_class/data_poly.csv

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
76.7007086033,663.797275569
2+
95.2735441552,1014.3622816
3+
73.0957232493,618.938826916
4+
46.9516354572,288.012877367
5+
33.3137480056,144.977555864
6+
58.8001284334,412.327812022
7+
86.4775958313,844.415014024
8+
26.1438291437,89.351909563
9+
97.6793058523,1053.20525097
10+
43.7453159626,240.908778383
11+
27.3960599569,115.936712225
12+
97.7709380144,1034.55054041
13+
0.477748646463,3.44508498982
14+
45.8230852085,254.058885722
15+
54.8058554165,353.273447055
16+
90.1526491199,926.945048811
17+
7.25055401942,30.7577367238
18+
45.32559552,241.735203861
19+
78.0299288822,692.797188617
20+
56.9992954919,382.995647161
21+
94.0058295504,978.718285869
22+
26.9031773559,98.2290062278
23+
23.0527128171,84.9785703587
24+
72.1620607036,584.521550996
25+
99.03856132,1082.25149734
26+
3.8567282135,5.37807424483
27+
29.1370660814,122.013630198
28+
86.1406581482,841.297543579
29+
30.6637414672,131.606512241
30+
45.7038082016,267.357505947
31+
20.8365297577,52.1846807158
32+
57.1870229314,388.795855138
33+
35.6458427285,155.657694842
34+
19.3869531878,87.1260781016
35+
21.9669393722,83.7232155716
36+
68.2754384479,530.711261039
37+
0.0499320826943,14.4928408734
38+
88.0160403117,843.24850174
39+
52.302680128,340.811945309
40+
30.8083679391,128.066959237
41+
36.7450495784,172.783502591
42+
32.497674786,147.228458255
43+
69.2984401147,538.180640028
44+
8.50775977301,18.1584427007
45+
70.340432818,559.410466793
46+
57.5800635973,384.884606047
47+
88.4806455787,875.860995014
48+
18.4605500858,58.9483005575
49+
55.8041981373,376.843030927
50+
16.4395710474,24.9459291843
51+
11.5900497701,32.6626606312
52+
46.8705089511,262.885902286
53+
84.33492545,798.336107587
54+
47.0879172063,278.115426472
55+
25.3595168427,104.171319469
56+
57.5279893076,411.579592637
57+
64.1988259571,471.645130755
58+
4.21211019461,-4.60107066464
59+
54.0943345308,363.848079875
60+
74.1390699605,634.232884203
61+
73.6494046203,614.61698055
62+
36.128161214,165.31378043
63+
52.0563603735,326.302161959
64+
47.1946296354,282.326194483
65+
38.4175210129,199.427772234
66+
54.6207713889,349.23680114
67+
64.7683580294,489.500982377
68+
84.4005467356,805.013871913
69+
21.4404800201,67.0955374565
70+
73.2615461079,611.057504623
71+
20.1698478817,69.446653099
72+
31.4337548498,138.906777891
73+
12.5179786376,48.9516001981
74+
3.61761195371,5.32627355217
75+
81.2515736807,731.3845794
76+
56.8424497509,378.289439943
77+
82.2665238066,748.343138324
78+
37.4722110019,175.748739041
79+
84.6946340015,800.047623475
80+
39.7685020685,203.056324686
81+
2.05616850915,6.23518914756
82+
97.1247091279,1042.6919584
83+
31.1291991206,130.113663552
84+
84.6126409599,790.864163252
85+
85.9561287143,827.455642687
86+
80.9652012247,738.33866824
87+
88.8979965529,893.81365673
88+
58.3641047494,402.882690162
89+
50.3024480011,314.117036852
90+
12.7419734701,12.9681174938
91+
77.320666256,687.654958048
92+
98.9646405094,1089.94766197
93+
56.3266253207,372.340305154
94+
79.8360680928,722.043849621
95+
84.2993047538,798.10497093
96+
45.3458596295,251.755172506
97+
3.13987798949,4.26261828882
98+
70.1062835334,550.923455067
99+
80.3106828908,728.06984768
100+
72.0680436903,581.130210853
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import numpy as np
2+
3+
N = 100
4+
w = np.array([2, 3])
5+
with open('data_2d.csv', 'w') as f:
6+
X = np.random.uniform(low=0, high=100, size=(N,2))
7+
Y = np.dot(X, w) + 1 + np.random.normal(scale=5, size=N)
8+
for i in xrange(N):
9+
f.write("%s,%s,%s\n" % (X[i,0], X[i,1], Y[i]))
+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import numpy as np
2+
3+
N = 100
4+
with open('data_poly.csv', 'w') as f:
5+
X = np.random.uniform(low=0, high=100, size=N)
6+
X2 = X*X
7+
Y = 0.1*X2 + X + 3 + np.random.normal(scale=10, size=N)
8+
for i in xrange(N):
9+
f.write("%s,%s\n" % (X[i], Y[i]))
10+

linear_regression_class/lr_2d.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as np
2+
from mpl_toolkits.mplot3d import Axes3D
3+
import matplotlib.pyplot as plt
4+
5+
6+
# load the data
7+
X = []
8+
Y = []
9+
for line in open('data_2d.csv'):
10+
x1, x2, y = line.split(',')
11+
X.append([1, float(x1), float(x2)]) # add the bias term x0 = 1
12+
Y.append(float(y))
13+
14+
# let's turn X and Y into numpy arrays since that will be useful later
15+
X = np.array(X)
16+
Y = np.array(Y)
17+
18+
19+
# let's plot the data to see what it looks like
20+
fig = plt.figure()
21+
ax = fig.add_subplot(111, projection='3d')
22+
ax.scatter(X[:,0], X[:,1], Y)
23+
plt.show()
24+
25+
26+
# apply the equations we learned to calculate a and b
27+
# numpy has a special method for solving Ax = b
28+
# so we don't use x = inv(A)*b
29+
# note: the * operator does element-by-element multiplication in numpy
30+
# np.dot() does what we expect for matrix multiplication
31+
w = np.linalg.solve(np.dot(X.T, X), np.dot(X.T, Y))
32+
Yhat = np.dot(X, w)
33+
34+
35+
# determine how good the model is by computing the r-squared
36+
d1 = Y - Yhat
37+
d2 = Y - Y.mean()
38+
r2 = 1 - d1.dot(d1) / d2.dot(d2)
39+
print "the r-squared is:", r2

linear_regression_class/lr_poly.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
4+
5+
# load the data
6+
X = []
7+
Y = []
8+
for line in open('data_poly.csv'):
9+
x, y = line.split(',')
10+
x = float(x)
11+
X.append([1, x, x*x]) # add the bias term x0 = 1
12+
Y.append(float(y))
13+
14+
# let's turn X and Y into numpy arrays since that will be useful later
15+
X = np.array(X)
16+
Y = np.array(Y)
17+
18+
19+
# let's plot the data to see what it looks like
20+
plt.scatter(X[:,1], Y)
21+
plt.show()
22+
23+
24+
# apply the equations we learned to calculate a and b
25+
# numpy has a special method for solving Ax = b
26+
# so we don't use x = inv(A)*b
27+
# note: the * operator does element-by-element multiplication in numpy
28+
# np.dot() does what we expect for matrix multiplication
29+
w = np.linalg.solve(np.dot(X.T, X), np.dot(X.T, Y))
30+
Yhat = np.dot(X, w)
31+
32+
33+
# let's plot everything together to make sure it worked
34+
plt.scatter(X[:,1], Y)
35+
plt.plot(sorted(X[:,1]), sorted(Yhat))
36+
# note: shortcut since monotonically increasing
37+
# x-axis values have to be in order since the points
38+
# are joined from one element to the next
39+
plt.show()
40+
41+
42+
# determine how good the model is by computing the r-squared
43+
d1 = Y - Yhat
44+
d2 = Y - Y.mean()
45+
r2 = 1 - d1.dot(d1) / d2.dot(d2)
46+
print "the r-squared is:", r2

0 commit comments

Comments
 (0)