Skip to content

Commit 3ba9625

Browse files
tiny xor update
1 parent a5f5e7e commit 3ba9625

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

logistic_regression_class/logistic_xor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
ones = np.ones((N, 1))
2525

2626
# add a column of xy = x*y
27-
xy = np.matrix(X[:,0] * X[:,1]).T
28-
Xb = np.array(np.concatenate((ones, xy, X), axis=1))
27+
xy = (X[:,0] * X[:,1]).reshape(N, 1)
28+
Xb = np.concatenate((ones, xy, X), axis=1)
2929

3030
# randomly initialize the weights
3131
w = np.random.randn(D + 2)
@@ -60,7 +60,6 @@ def cross_entropy(T, Y):
6060
print e
6161

6262
# gradient descent weight udpate with regularization
63-
# w += learning_rate * ( np.dot((T - Y).T, Xb) - 0.01*w )
6463
w += learning_rate * ( Xb.T.dot(T - Y) - 0.01*w )
6564

6665
# recalculate Y

0 commit comments

Comments
 (0)