-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTestNN.m
54 lines (43 loc) · 1.99 KB
/
TestNN.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
% TestNN.m - Script to test the neural network and print performance metrics
%
% Inputs (files / workspace):
%   trainedNetwork.mat   - must contain the variable 'net'.
%   testFeatures         - taken from the workspace if present, otherwise
%   testLabels             loaded from preprocessedData.mat.
%                          NOTE(review): testFeatures is assumed to be
%                          samples-by-features (it is transposed before being
%                          passed to the network) - confirm against the
%                          preprocessing script.
%
% Workspace outputs: testOutputs, testPerformance, predictedLabels,
%   actualLabels, confusionMatrix, overallAccuracy, precision, recall,
%   f1Scores, classAccuracies. Two figures are opened (confusion plot, ROC).

% Load the trained network
load('trainedNetwork.mat', 'net');

% Use testFeatures/testLabels from the workspace when both exist;
% otherwise load them from disk.
if ~exist('testFeatures', 'var') || ~exist('testLabels', 'var')
    load('preprocessedData.mat', 'testFeatures', 'testLabels');
end

% Convert the labels to categorical. Labels that are already categorical are
% used as they are (previously this case left testLabelsCategorical undefined
% and the script crashed on the next statement).
if iscategorical(testLabels)
    testLabelsCategorical = testLabels;
else
    testLabelsCategorical = categorical(testLabels);
end

% One-hot encode: after this line testLabelsCategorical is a
% classes-by-samples 0/1 matrix (the variable name is kept so that code
% relying on this workspace variable keeps working).
testLabelsCategorical = dummyvar(testLabelsCategorical)';

% Transpose into a separate variable. Overwriting testFeatures in place
% would make a second run of this script transpose it back again, because
% the existence check above reuses the workspace variable.
testInputs = testFeatures';

% Test the network
testOutputs = net(testInputs);

% Fail with a clear message when targets and outputs disagree in size, e.g.
% when the test set does not contain every class the network can output.
if ~isequal(size(testLabelsCategorical), size(testOutputs))
    error('TestNN:sizeMismatch', ...
        'One-hot targets are %d-by-%d but network outputs are %d-by-%d.', ...
        size(testLabelsCategorical, 1), size(testLabelsCategorical, 2), ...
        size(testOutputs, 1), size(testOutputs, 2));
end

testPerformance = perform(net, testLabelsCategorical, testOutputs);

% Class index of the largest output / the hot entry for every sample
numClasses = size(testLabelsCategorical, 1);
[~, predictedIdx] = max(testOutputs, [], 1);
[~, actualIdx] = max(testLabelsCategorical, [], 1);
predictedLabels = categorical(predictedIdx);
actualLabels = categorical(actualIdx);

% Confusion matrix: rows = true class, columns = predicted class.
% 'Order' pins the matrix to numClasses-by-numClasses in class-index order,
% even when some class is never predicted or never occurs.
confusionMatrix = confusionmat(actualIdx, predictedIdx, 'Order', 1:numClasses);

truePositives   = diag(confusionMatrix);
actualCounts    = sum(confusionMatrix, 2);    % row sums: samples truly in each class
predictedCounts = sum(confusionMatrix, 1)';   % column sums: samples predicted as each class
totalSamples    = sum(confusionMatrix, 'all');

overallAccuracy = sum(truePositives) / totalSamples;

% Precision divides by what was predicted, recall by what was actually
% there. Either is NaN (undefined) for a class with a zero denominator.
precision = truePositives ./ predictedCounts;
recall    = truePositives ./ actualCounts;

% F1 = 2*TP / (2*TP + FP + FN); algebraically equal to
% 2*P*R/(P+R) but stays defined (0) when a class has no correct predictions.
f1Scores = 2 * truePositives ./ (predictedCounts + actualCounts);

% One-vs-rest accuracy per class: (TP + TN) / total, where
% TN = total - actualCounts - predictedCounts + TP.
classAccuracies = (totalSamples - actualCounts - predictedCounts + 2 * truePositives) ...
    / totalSamples;

% Print out the metrics
fprintf('Test Performance: %.4f\n', testPerformance);
fprintf('Overall Accuracy: %.4f\n', overallAccuracy);
for i = 1:numClasses
    fprintf('Class %d - Accuracy: %.4f, Precision: %.4f, Recall: %.4f, F1: %.4f\n', ...
        i, classAccuracies(i), precision(i), recall(i), f1Scores(i));
end

% Plot the confusion matrix
figure;
plotconfusion(testLabelsCategorical, testOutputs);

% Optional: Plot ROC curve
figure;
plotroc(testLabelsCategorical, testOutputs);