-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheckNNGradients.m
51 lines (39 loc) · 1.95 KB
/
checkNNGradients.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
function checkNNGradients(lambda)
%CHECKNNGRADIENTS Compare backpropagation gradients against numerical ones.
%   CHECKNNGRADIENTS(lambda) builds a small neural network, evaluates the
%   analytical gradient returned by nnCostFunction and the numerical
%   gradient returned by computeNumericalGradient for the same parameters,
%   prints both side by side, and prints their relative difference.
%
%   lambda  Regularization strength forwarded to nnCostFunction.
%           Optional; defaults to 0 when omitted or empty.
%
%   The function returns nothing; its results are written to the console.

    if ~exist('lambda', 'var') || isempty(lambda)
        lambda = 0;
    end

    % Dimensions of the small test network and the number of examples.
    input_layer_size = 3;
    hidden_layer_size = 5;
    num_labels = 3;
    m = 5;

    % We generate some 'random' test data
    Theta1 = debugInitializeWeights(hidden_layer_size, input_layer_size);
    Theta2 = debugInitializeWeights(num_labels, hidden_layer_size);

    % Reusing debugInitializeWeights to generate X.
    % NOTE(review): presumably debugInitializeWeights appends one extra
    % (bias) column, so input_layer_size - 1 yields input_layer_size
    % feature columns -- confirm against debugInitializeWeights.m.
    X = debugInitializeWeights(m, input_layer_size - 1);

    % Labels cycle through 1..num_labels, as a column vector of length m.
    y = 1 + mod(1:m, num_labels)';

    % Unrolling parameters into a single column vector
    nn_params = [Theta1(:) ; Theta2(:)];

    % Short hand for cost function: only the parameter vector varies.
    costFunc = @(p) nnCostFunction(p, input_layer_size, hidden_layer_size, ...
                                   num_labels, X, y, lambda);

    % Only the gradient is needed here; the cost output is discarded.
    [~, grad] = costFunc(nn_params);
    numgrad = computeNumericalGradient(costFunc, nn_params);

    % Visually examine the two gradient computations. The two columns
    % should be very similar.
    disp([numgrad grad]);
    fprintf(['The above two columns you get should be very similar.\n' ...
             '(Left-Your Numerical Gradient, Right-Analytical Gradient)\n\n']);

    % Evaluating the norm of the difference between two solutions.
    % If we have a correct implementation, and assuming you used
    % EPSILON = 0.0001 in computeNumericalGradient.m, then the relative
    % difference below should be less than 1e-9.
    % (Named rel_diff rather than diff so the built-in diff function is
    % not shadowed inside this function.)
    rel_diff = norm(numgrad - grad) / norm(numgrad + grad);

    fprintf(['If your backpropagation implementation is correct, then \n' ...
             'the relative difference will be small (less than 1e-9). \n' ...
             '\nRelative Difference: %g\n'], rel_diff);
end