forked from arthenica/tesseract
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmastertrainer_test.cc
308 lines (282 loc) · 12.2 KB
/
mastertrainer_test.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Although this is a trivial-looking test, it exercises a lot of code:
// SampleIterator has to correctly iterate over the correct characters, or
// it will fail.
// The canonical and cloud features computed by TrainingSampleSet need to
// be correct, along with the distance caches, organizing samples by font
// and class, indexing of features, distance calculations.
// IntFeatureDist has to work, or the canonical samples won't work.
// MasterTrainer's ability to read tr files and set itself up is tested.
// Finally the serialize/deserialize test ensures that MasterTrainer,
// TrainingSampleSet, TrainingSample can all serialize/deserialize correctly
// enough to reproduce the same results.
#include <cctype>   // for std::isspace
#include <cerrno>   // for errno, ERANGE
#include <climits>  // for INT_MIN, INT_MAX
#include <cstdlib>  // for std::strtol
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/numbers.h" // for safe_strto32
#include "absl/strings/str_split.h" // for absl::StrSplit
#include "include_gunit.h"
#include "genericvector.h"
#include "log.h" // for LOG
#include "unicharset.h"
#include "errorcounter.h"
#include "mastertrainer.h"
#include "shapeclassifier.h"
#include "shapetable.h"
#include "trainingsample.h"
#include "commontraining.h"
#include "tessopt.h" // tessoptind
// Error pattern emitted by the MockClassifier. The four *Errs values are
// cumulative thresholds on the number of samples classified so far: samples
// up to kNumTopNErrs get a top-n error, samples up to kNumTop2Errs a top-2
// error, and so on.
static constexpr int kNumTopNErrs = 10;
static constexpr int kNumTop2Errs = 20 + kNumTopNErrs;
static constexpr int kNumTop1Errs = 30 + kNumTop2Errs;
static constexpr int kNumTopTopErrs = 25 + kNumTop1Errs;
// Samples beyond this count are rejected (no answer at all).
static constexpr int kNumNonReject = 1000;
// Non-rejected samples that are not counted as a top-1 error.
static constexpr int kNumCorrect = kNumNonReject - kNumTop1Errs;
// Total answers returned: one per non-rejected sample, plus the extra
// answers of every sample that received more than one.
static constexpr int kNumAnswers =
    kNumNonReject +
    (kNumTop2Errs - kNumTopNErrs) * 2 +  // top-2 errors: two extra answers
    (kNumTop1Errs - kNumTop2Errs) +      // top-1 errors: one extra answer
    (kNumTopTopErrs - kNumTop1Errs);     // top-top errors: one extra answer
#ifndef DISABLED_LEGACY_ENGINE
// Parses str as a signed integer that must fit in an int.
//
// The numeric base is auto-detected by strtol (base 0), and leading
// whitespace is skipped by strtol. Trailing whitespace (for instance a line
// terminator left on the last field of a line) is tolerated; any other
// trailing character makes the parse fail.
//
// Returns true and stores the value in *pResult on success. Returns false
// and stores 0 in *pResult if str holds no number, has trailing garbage, or
// the value does not fit in an int. Returns false if pResult is null.
static bool safe_strto32(const std::string& str, int* pResult)
{
  if (pResult == nullptr) {
    return false;
  }
  // Give the output a defined value on every failure path.
  *pResult = 0;

  const char* const begin = str.c_str();
  const char* const last = begin + str.size();
  char* end = nullptr;
  errno = 0;  // strtol reports overflow only through errno.
  const long parsed = std::strtol(begin, &end, 0);
  if (end == begin) {
    return false;  // No digits were consumed.
  }
  if (errno == ERANGE || parsed < INT_MIN || parsed > INT_MAX) {
    return false;  // Out of range for long or for int.
  }
  // Everything after the number must be whitespace. Comparing against the
  // real end of the string also rejects embedded NUL characters.
  for (const char* p = end; p != last; ++p) {
    if (!std::isspace(static_cast<unsigned char>(*p))) {
      return false;
    }
  }
  *pResult = static_cast<int>(parsed);
  return true;
}
#endif
namespace tesseract {
// A stand-in ShapeClassifier that peeks at the true class and font of every
// sample and answers with a fixed, countable pattern of mistakes, so that
// the totals computed from its answers can be checked exactly.
class MockClassifier : public ShapeClassifier {
 public:
  // Registers one extra shape in shape_table: unichar id 67 paired with
  // font 25. It is emitted once as a deliberately wrong-font answer. The
  // font must actually exist in the font table, but must not match anything
  // in the first 1000 samples.
  // NOTE: the initializers below rely on the member declaration order
  // (shape_table_, false_unichar_id_, false_shape_).
  explicit MockClassifier(ShapeTable* shape_table)
      : shape_table_(shape_table),
        false_unichar_id_(67),
        false_shape_(shape_table_->AddShape(false_unichar_id_, 25)),
        num_done_(0),
        done_bad_font_(false) {}

  virtual ~MockClassifier() = default;

  // Classifies the given [training] sample, writing the answers to results
  // (which is cleared first) and returning how many answers were written.
  // page_pix, debug and keep_this are part of the classifier interface but
  // are ignored by this mock.
  virtual int ClassifySample(const TrainingSample& sample, Pix* page_pix,
                             int debug, UNICHAR_ID keep_this,
                             GenericVector<ShapeRating>* results) {
    results->clear();
    // Only the first kNumNonReject samples get any answer at all.
    ++num_done_;
    if (num_done_ > kNumNonReject) {
      return 0;
    }

    const int true_class = sample.class_id();
    const int true_shape = shape_table_->FindShape(true_class, sample.font_id());
    // Neighbouring shape ids serve as wrong answers; step downwards for ids
    // above 10 and upwards otherwise.
    const int step = true_shape > 10 ? -1 : 1;
    const int wrong_shape1 = true_shape + step;
    const int wrong_shape2 = true_shape + 2 * step;

    if (num_done_ <= kNumTopTopErrs) {
      // Every scripted error puts a wrong shape on top.
      results->push_back(ShapeRating(wrong_shape1, 1.0f));
      if (num_done_ > kNumTopNErrs) {
        if (num_done_ <= kNumTop2Errs) {
          // Top-2 error: the right shape only comes third.
          results->push_back(ShapeRating(wrong_shape2, 0.875f));
          results->push_back(ShapeRating(true_shape, 0.75f));
        } else if (num_done_ <= kNumTop1Errs) {
          // Top-1 error: the right shape is a clear second.
          results->push_back(ShapeRating(true_shape, 0.8f));
        } else {
          // The top answer is wrong, but the right one is rated so close to
          // it that this does not count as a top-1 error.
          results->push_back(ShapeRating(true_shape, 0.99f));
        }
      }
      // Otherwise (top-n error) the right shape is not offered at all.
    } else if (!done_bad_font_ && true_class == false_unichar_id_) {
      // Exactly one character is answered with the false-font shape.
      done_bad_font_ = true;
      results->push_back(ShapeRating(false_shape_, 1.0f));
    } else {
      // All remaining samples are answered correctly.
      results->push_back(ShapeRating(true_shape, 1.0f));
    }
    return results->size();
  }

  // Returns the ShapeTable that this classifier answers from.
  virtual const ShapeTable* GetShapeTable() const { return shape_table_; }

 private:
  // Borrowed, not owned.
  ShapeTable* shape_table_;
  // Unichar id whose first occurrence after the scripted error samples is
  // answered with false_shape_.
  int false_unichar_id_;
  // Index of the shape added for false_unichar_id_ with the false font.
  int false_shape_;
  // Count of ClassifySample calls so far.
  int num_done_;
  // Set once the false-font answer has been given.
  bool done_bad_font_;
};
// Shape distances between '1' and either 'l' or 'I' must exceed this value.
constexpr double kMin1lDistance = 0.25;
// The fixture for testing Tesseract.
class MasterTrainerTest : public testing::Test {
#ifndef DISABLED_LEGACY_ENGINE
protected:
void SetUp() {
std::locale::global(std::locale(""));
file::MakeTmpdir();
}
std::string TestDataNameToPath(const std::string& name) {
return file::JoinPath(TESTING_DIR, name);
}
std::string TmpNameToPath(const std::string& name) {
return file::JoinPath(FLAGS_test_tmpdir, name);
}
MasterTrainerTest() {
shape_table_ = nullptr;
master_trainer_ = nullptr;
}
~MasterTrainerTest() {
delete master_trainer_;
delete shape_table_;
}
// Initializes the master_trainer_ and shape_table_.
// if load_from_tmp, then reloads a master trainer that was saved by a
// previous call in which it was false.
void LoadMasterTrainer() {
FLAGS_output_trainer = TmpNameToPath("tmp_trainer").c_str();
FLAGS_F = file::JoinPath(LANGDATA_DIR, "font_properties").c_str();
FLAGS_X = TestDataNameToPath("eng.xheights").c_str();
FLAGS_U = TestDataNameToPath("eng.unicharset").c_str();
std::string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
const char* argv[] = {tr_file_name.c_str()};
int argc = 1;
STRING file_prefix;
delete master_trainer_;
delete shape_table_;
shape_table_ = nullptr;
tessoptind = 0;
master_trainer_ =
LoadTrainingData(argc, argv, false, &shape_table_, &file_prefix);
EXPECT_TRUE(master_trainer_ != nullptr);
EXPECT_TRUE(shape_table_ != nullptr);
}
// EXPECTs that the distance between I and l in Arial is 0 and that the
// distance to 1 is significantly not 0.
void VerifyIl1() {
// Find the font id for Arial.
int font_id = master_trainer_->GetFontInfoId("Arial");
EXPECT_GE(font_id, 0);
// Track down the characters we are interested in.
int unichar_I = master_trainer_->unicharset().unichar_to_id("I");
EXPECT_GT(unichar_I, 0);
int unichar_l = master_trainer_->unicharset().unichar_to_id("l");
EXPECT_GT(unichar_l, 0);
int unichar_1 = master_trainer_->unicharset().unichar_to_id("1");
EXPECT_GT(unichar_1, 0);
// Now get the shape ids.
int shape_I = shape_table_->FindShape(unichar_I, font_id);
EXPECT_GE(shape_I, 0);
int shape_l = shape_table_->FindShape(unichar_l, font_id);
EXPECT_GE(shape_l, 0);
int shape_1 = shape_table_->FindShape(unichar_1, font_id);
EXPECT_GE(shape_1, 0);
float dist_I_l =
master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_l);
// No tolerance here. We expect that I and l should match exactly.
EXPECT_EQ(0.0f, dist_I_l);
float dist_l_I =
master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_I);
// BOTH ways.
EXPECT_EQ(0.0f, dist_l_I);
// l/1 on the other hand should be distinct.
float dist_l_1 =
master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_1);
EXPECT_GT(dist_l_1, kMin1lDistance);
float dist_1_l =
master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_l);
EXPECT_GT(dist_1_l, kMin1lDistance);
// So should I/1.
float dist_I_1 =
master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_1);
EXPECT_GT(dist_I_1, kMin1lDistance);
float dist_1_I =
master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_I);
EXPECT_GT(dist_1_I, kMin1lDistance);
}
// Objects declared here can be used by all tests in the test case for Foo.
ShapeTable* shape_table_;
MasterTrainer* master_trainer_;
#endif
};
// Tests that the MasterTrainer correctly loads its data and reaches the correct
// conclusion over the distance between Arial I l and 1.
TEST_F(MasterTrainerTest, Il1Test) {
#ifdef DISABLED_LEGACY_ENGINE
  // Skip test because LoadTrainingData is missing.
  GTEST_SKIP();
#else
  // Initialize the master_trainer_ and load the Arial tr file.
  LoadMasterTrainer();
  // Fail cleanly instead of dereferencing a null pointer if loading failed.
  ASSERT_TRUE(master_trainer_ != nullptr);
  ASSERT_TRUE(shape_table_ != nullptr);
  VerifyIl1();
#endif
}
// Tests the ErrorCounter using a MockClassifier to check that it counts
// error categories correctly.
TEST_F(MasterTrainerTest, ErrorCounterTest) {
#ifdef DISABLED_LEGACY_ENGINE
  // Skip test because LoadTrainingData is missing.
  GTEST_SKIP();
#else
  // Initialize the master_trainer_ and load the Arial tr file.
  LoadMasterTrainer();
  // Fail cleanly instead of dereferencing a null pointer if loading failed.
  ASSERT_TRUE(master_trainer_ != nullptr);
  ASSERT_TRUE(shape_table_ != nullptr);
  // Add the space character to the shape_table_ if not already present to
  // count junk.
  if (shape_table_->FindShape(0, -1) < 0) shape_table_->AddShape(0, 0);
  // Make a mock classifier. It lives on the stack so that it is released
  // even when a fatal assertion below leaves the test early.
  tesseract::MockClassifier shape_classifier(shape_table_);
  // Get the accuracy report.
  STRING accuracy_report;
  master_trainer_->TestClassifierOnSamples(tesseract::CT_UNICHAR_TOP1_ERR, 0,
                                           false, &shape_classifier,
                                           &accuracy_report);
  LOG(INFO) << accuracy_report.c_str();
  // The report is split on tabs: one leading text field, then one count per
  // error category.
  std::string result_string = accuracy_report.c_str();
  std::vector<std::string> results =
      absl::StrSplit(result_string, '\t', absl::SkipEmpty());
  // Fatal: the loop below indexes results[1..CT_SIZE], which would read out
  // of bounds if the report had fewer fields.
  ASSERT_EQ(tesseract::CT_SIZE + 1, static_cast<int>(results.size()));
  int result_values[tesseract::CT_SIZE] = {};
  for (int i = 0; i < tesseract::CT_SIZE; ++i) {
    EXPECT_TRUE(safe_strto32(results[i + 1], &result_values[i]));
  }
  // These tests are more-or-less immune to additions to the number of
  // categories or changes in the training data.
  int num_samples = master_trainer_->GetSamples()->num_raw_samples();
  EXPECT_EQ(kNumCorrect, result_values[tesseract::CT_UNICHAR_TOP_OK]);
  EXPECT_EQ(1, result_values[tesseract::CT_FONT_ATTR_ERR]);
  EXPECT_EQ(kNumTopTopErrs, result_values[tesseract::CT_UNICHAR_TOPTOP_ERR]);
  EXPECT_EQ(kNumTop1Errs, result_values[tesseract::CT_UNICHAR_TOP1_ERR]);
  EXPECT_EQ(kNumTop2Errs, result_values[tesseract::CT_UNICHAR_TOP2_ERR]);
  EXPECT_EQ(kNumTopNErrs, result_values[tesseract::CT_UNICHAR_TOPN_ERR]);
  // Each of the TOPTOP errs also counts as a multi-unichar.
  EXPECT_EQ(kNumTopTopErrs - kNumTop1Errs,
            result_values[tesseract::CT_OK_MULTI_UNICHAR]);
  EXPECT_EQ(num_samples - kNumNonReject, result_values[tesseract::CT_REJECT]);
  EXPECT_EQ(kNumAnswers, result_values[tesseract::CT_NUM_RESULTS]);
#endif
}
} // namespace tesseract