diff --git a/NNpredict.m b/NNpredict.m index 4bf3c99..183e530 100644 --- a/NNpredict.m +++ b/NNpredict.m @@ -1,17 +1,16 @@ function p = NNpredict(Theta1, Theta2, X) + % This function outputs the predicted label of X given the trained weights + % of a neural network (Theta1, Theta2). + % Useful values m = size(X, 1); num_labels = size(Theta2, 1); - p = zeros(size(X, 1), 1); - % Add ones to the X data matrix - X = [ones(m, 1) X]; + p = zeros(size(X, 1), 1); - newX = sigmoid(X*Theta1'); - newX = [ones(m, 1) newX]; - p = sigmoid(newX*Theta2'); - [maxP, indP] = max(p, [], 2); - p = indP; + h1 = sigmoid([ones(m, 1) X] * Theta1'); + h2 = sigmoid([ones(m, 1) h1] * Theta2'); + [dummy, p] = max(h2, [], 2); end