forked from nut-tree/opencv4nodejs
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmachineLearningOCR.ts
123 lines (105 loc) · 3.38 KB
/
machineLearningOCR.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import * as fs from 'fs';
import * as cv from '../../';
import {
lccs,
centerLetterInImage,
saveConfusionMatrix
} from './OCRTools';
// Relative data locations; fs resolves these against the process's current working directory.
// Each of trainDataPath / testDataPath holds one subdirectory per letter in `lccs`.
const trainDataPath = '../../data/ocr/traindata';
const testDataPath = '../../data/ocr/testdata';
// Output directory for the trained model and the confusion-matrix CSV.
const outPath = '../../data/ocr';
// File name (under outPath) the trained SVM is saved to and reloaded from.
const SVMFile = 'lcletters.xml';
// HOG feature extractor. The 40x40 window matches the 40x40 resize done in
// computeHOGDescriptorFromImage; keep the two in sync.
const hog = new cv.HOGDescriptor({
winSize: new cv.Size(40, 40),
blockSize: new cv.Size(20, 20),
blockStride: new cv.Size(10, 10),
cellSize: new cv.Size(10, 10),
L2HysThreshold: 0.2,
nbins: 9,
gammaCorrection: true,
signedGradient: true
});
// RBF-kernel SVM classifier. NOTE(review): c and gamma look like previously tuned
// values (e.g. from a trainAuto run) — their origin is not recorded here; confirm before changing.
const svm = new cv.SVM({
kernelType: cv.ml.SVM.RBF,
c: 12.5,
gamma: 0.50625
});
/**
 * Computes the HOG feature vector for a single letter image.
 *
 * The image is first scaled to 40x40 (the window size configured on `hog`).
 * The letter is then centered with `centerLetterInImage`, and the HOG
 * descriptor of the centered image is returned.
 *
 * @param img - Letter image as loaded by `cv.imread`.
 * @param isIorJ - Forwarded to `centerLetterInImage`. Callers set it for class
 *   indices 8 and 9 (presumably 'i' and 'j', whose dots make the glyph
 *   multi-part — confirm in OCRTools).
 * @returns The HOG descriptor, or `null` when `centerLetterInImage` yields no
 *   image (presumably when no letter could be located — confirm in OCRTools).
 */
const computeHOGDescriptorFromImage = (img: cv.Mat, isIorJ: boolean) => {
  let resized = img;
  if (resized.rows !== 40 || resized.cols !== 40) {
    resized = resized.resize(40, 40);
  }

  // Center the letter in the *resized* image, so the 40x40 normalization above
  // actually reaches hog.compute (assumes centerLetterInImage keeps the input
  // size — TODO confirm in OCRTools).
  const centered = centerLetterInImage(resized, isIorJ);
  // Check the centering result, not the input image: a falsy result means
  // centering failed and there is nothing to describe.
  if (!centered) {
    return null;
  }
  return hog.compute(centered);
};
/**
 * Builds HOG features for every training image, labels each one with its class
 * index, and trains the module-level `svm` on the result.
 *
 * @param trainDataFiles - One array of image paths per class. The outer array
 *   index is used as the class label.
 * @param isAuto - When true, trains with `svm.trainAuto` (OpenCV's
 *   cross-validated parameter search). Otherwise uses `svm.train` with the
 *   parameters the SVM was constructed with.
 * @throws Error if no training image yields a HOG descriptor. An empty sample
 *   matrix would otherwise fail inside the native binding with an opaque message.
 */
const trainSVM = (trainDataFiles: string[][], isAuto = false) => {
  // make hog features of trainingData and label it
  console.log('make features');
  const samples: number[][] = [];
  const labels: number[] = [];
  const skippedFiles: string[] = [];
  trainDataFiles.forEach((files, label) => {
    // Label indices 8 and 9 get special centering (assumed to be 'i' and 'j' in lccs).
    const isIorJ = label === 8 || label === 9;
    files.forEach((file) => {
      const img = cv.imread(file);
      const desc = computeHOGDescriptorFromImage(img, isIorJ);
      if (!desc) {
        // Best-effort: leave unusable images out of training, but record them.
        skippedFiles.push(file);
        return;
      }
      samples.push(desc);
      labels.push(label);
    });
  });

  if (skippedFiles.length > 0) {
    console.warn(
      `skipped ${skippedFiles.length} training image(s) without a HOG descriptor:`,
      skippedFiles.join(', ')
    );
  }
  if (samples.length === 0) {
    throw new Error('No training samples: HOG descriptor computation failed for every training image');
  }

  // train the SVM
  console.log('training');
  const trainData = new cv.TrainData(
    new cv.Mat(samples, cv.CV_32F),
    cv.ml.ROW_SAMPLE,
    new cv.Mat([labels], cv.CV_32S)
  );
  // Explicit dispatch instead of indexing with a computed method name, so each
  // call is type-checked against its own signature.
  if (isAuto) {
    svm.trainAuto(trainData);
  } else {
    svm.train(trainData);
  }
};
/**
 * Per-class image lists, in `lccs` order. Entry `k` has `train` (paths under
 * `${trainDataPath}/${lccs[k]}`) and `test` (paths under `${testDataPath}/${lccs[k]}`),
 * each in the order `fs.readdirSync` returns them.
 */
const data = lccs.map((letter) => {
  // Full paths of every entry in `dir`.
  const listFiles = (dir: string) => fs.readdirSync(dir).map(file => `${dir}/${file}`);

  const train = listFiles(`${trainDataPath}/${letter}`);
  const test = listFiles(`${testDataPath}/${letter}`);
  return { train, test };
});
// Class-indexed path lists: entry `label` holds the images of letter lccs[label].
const trainDataFiles = data.map(({ train }) => train);
const testDataFiles = data.map(({ test }) => test);
// NOTE: these "per class" counts come from the first class only, so they are
// accurate only when every class has the same number of images.
const numTrainImagesPerClass = trainDataFiles[0].length;
const numTestImagesPerClass = testDataFiles[0].length;
console.log('train data per class:', numTrainImagesPerClass);
console.log('test data per class:', numTestImagesPerClass);

// Train with the fixed SVM parameters (isAuto = false, i.e. svm.train rather
// than trainAuto), write the model to disk, then reload it from the same file.
const svmModelPath = `${outPath}/${SVMFile}`;
trainSVM(trainDataFiles, false);
svm.save(svmModelPath);
svm.load(svmModelPath);
// compute prediction error for each letter
// errs[label] counts the misclassified test images of class lccs[label]. The
// array is sized from lccs (not a hard-coded 26) so it always matches the number
// of classes, and is typed number[] instead of the any[] that Array(n) yields.
const errs = new Array<number>(lccs.length).fill(0);
testDataFiles.forEach((files, label) => {
  // Label indices 8 and 9 get special centering (assumed to be 'i' and 'j' in lccs).
  const isIorJ = label === 8 || label === 9;
  files.forEach((file) => {
    const img = cv.imread(file);
    const desc = computeHOGDescriptorFromImage(img, isIorJ);
    if (!desc) {
      // Unlike training, evaluation treats an unusable test image as fatal.
      throw new Error(`Computing HOG descriptor failed for file: ${file}`);
    }
    const predictedLabel = svm.predict(desc);
    if (label !== predictedLabel) {
      errs[label] += 1;
    }
  });
});
// Report per-class and overall accuracy, then write the confusion matrix.
// Each class's accuracy uses that class's own test-image count, and the average
// uses the total count. With equal-sized classes this matches dividing by
// numTestImagesPerClass, but unequal classes are no longer mis-scored.
console.log('prediction result:');
testDataFiles.forEach((files, l) => console.log(lccs[l], errs[l], 1 - (errs[l] / files.length)));
const numTestImages = testDataFiles.reduce((total, files) => total + files.length, 0);
const numErrors = errs.reduce((total, err) => total + err, 0);
console.log('average: ', 1 - (numErrors / numTestImages));
saveConfusionMatrix(
  testDataFiles,
  (img, isIorJ) => {
    const desc = computeHOGDescriptorFromImage(img, isIorJ);
    // Never hand a null descriptor to svm.predict; fail with a readable message instead.
    if (!desc) {
      throw new Error('Computing HOG descriptor failed while building the confusion matrix');
    }
    return svm.predict(desc);
  },
  numTestImagesPerClass,
  `${outPath}/confusionmatrix.csv`
);