Skip to content

Commit 715d4c5

Browse files
visualize lda
1 parent 21b821c commit 715d4c5

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# visualizes the Bayes solution
2+
#
3+
# the notes for this class can be found at:
4+
# https://www.udemy.com/data-science-logistic-regression-in-python
5+
6+
import numpy as np
7+
import matplotlib.pyplot as plt
8+
9+
N = 100
10+
D = 2
11+
12+
13+
X = np.random.randn(N,D)
14+
15+
# center the first 50 points at (-2,-2)
16+
X[:50,:] = X[:50,:] - 2*np.ones((50,D))
17+
18+
# center the last 50 points at (2, 2)
19+
X[50:,:] = X[50:,:] + 2*np.ones((50,D))
20+
21+
# labels: first 50 are 0, last 50 are 1
22+
T = np.array([0]*50 + [1]*50)
23+
24+
# add a column of ones
25+
ones = np.array([[1]*N]).T
26+
Xb = np.concatenate((ones, X), axis=1)
27+
28+
def sigmoid(z):
29+
return 1/(1 + np.exp(-z))
30+
31+
# get the closed-form solution
32+
w = np.array([0, 4, 4])
33+
34+
# calculate the model output
35+
z = Xb.dot(w)
36+
Y = sigmoid(z)
37+
38+
plt.scatter(X[:,0], X[:,1], c=T, s=100, alpha=0.5)
39+
40+
x_axis = np.linspace(-6, 6, 100)
41+
y_axis = -x_axis
42+
plt.plot(x_axis, y_axis)
43+
plt.show()

0 commit comments

Comments
 (0)