Skip to content

Commit

Permalink
Update NNpredict.m
Browse files Browse the repository at this point in the history
  • Loading branch information
Mohammad Mehdi Samsami authored Apr 19, 2018
1 parent 643ed2d commit 2ea09bf
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions NNpredict.m
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
function p = NNpredict(Theta1, Theta2, X)

% This function outputs the predicted label of X given the trained weights
% of a neural network (Theta1, Theta2).

% Useful values
m = size(X, 1);
num_labels = size(Theta2, 1);
p = zeros(size(X, 1), 1);

% Add ones to the X data matrix
X = [ones(m, 1) X];
p = zeros(size(X, 1), 1);

newX = sigmoid(X*Theta1');
newX = [ones(m, 1) newX];
p = sigmoid(newX*Theta2');
[maxP, indP] = max(p, [], 2);
p = indP;
h1 = sigmoid([ones(m, 1) X] * Theta1');
h2 = sigmoid([ones(m, 1) h1] * Theta2');
[dummy, p] = max(h2, [], 2);

end

0 comments on commit 2ea09bf

Please sign in to comment.