// Source: network.cpp, taken from a fork of arthenica/tesseract.
///////////////////////////////////////////////////////////////////////
// File: network.cpp
// Description: Base class for neural network implementations.
// Author: Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
// Include automatically generated configuration file if running autoconf.
#ifdef HAVE_CONFIG_H
# include "config_auto.h"
#endif
#include "network.h"
#include <cstdlib>
// This base class needs to know about all its sub-classes because of the
// factory deserializing method: CreateFromFile.
#include <allheaders.h>
#include "convolve.h"
#include "fullyconnected.h"
#include "input.h"
#include "lstm.h"
#include "maxpool.h"
#include "parallel.h"
#include "reconfig.h"
#include "reversed.h"
#include "scrollview.h"
#include "series.h"
#include "statistc.h"
#ifdef INCLUDE_TENSORFLOW
# include "tfnetwork.h"
#endif
#include "tprintf.h"
namespace tesseract {
#ifndef GRAPHICS_DISABLED
// Sizing limits used by Network::ClearWindow when it creates a debug window.
// If the smaller content dimension is below kMinWinSize, both dimensions are
// scaled up; after the frame allowance is added, each dimension is capped at
// kMaxWinSize.
constexpr int kMinWinSize = 500;
constexpr int kMaxWinSize = 2000;
// Allowance added to the content width/height so that the content still fits
// once the window frame is taken into account.
constexpr int kXWinFrameSize = 30;
constexpr int kYWinFrameSize = 80;
#endif // !GRAPHICS_DISABLED
// Serialized names of the network layer types, indexed by NetworkType value.
// Network::Serialize writes kTypeNames[type_] to the file and getNetworkType
// maps a name read from a file back to its index, so files identify layer
// types by name rather than by enum value. This allows NetworkType to be
// re-ordered or extended without invalidating existing network files.
// Keep the order in sync with the NetworkType enum.
static char const *const kTypeNames[NT_COUNT] = {
    "Invalid",
    "Input",
    "Convolve",
    "Maxpool",
    "Parallel",
    "Replicated",
    "ParBidiLSTM",
    "DepParUDLSTM",
    "Par2dLSTM",
    "Series",
    "Reconfig",
    "RTLReversed",
    "TTBReversed",
    "XYTranspose",
    "LSTM",
    "SummLSTM",
    "Logistic",
    "LinLogistic",
    "LinTanh",
    "Tanh",
    "Relu",
    "Linear",
    "Softmax",
    "SoftmaxNoCTC",
    "LSTMSoftmax",
    "LSTMBinarySoftmax",
    "TensorFlow",
};
// Default constructor: an unnamed network of type NT_NONE with zero inputs
// and zero outputs. Delegates to the main constructor, which supplies the
// remaining defaults (training enabled, backprop needed, no flags, no
// weights, no debug windows, no randomizer).
Network::Network() : Network(NT_NONE, std::string(), 0, 0) {}
// Constructs a network layer of the given type and name.
//   type: the kind of layer being built.
//   name: name stored in name_ (also used as the debug window title).
//   ni:   number of inputs, stored in ni_.
//   no:   number of outputs, stored in no_.
// The layer starts with training enabled, backprop required, no flags set,
// zero weights, and with no debug windows or randomizer attached.
Network::Network(NetworkType type, const std::string &name, int ni, int no)
    : type_(type),
      training_(TS_ENABLED),
      needs_to_backprop_(true),
      network_flags_(0),
      ni_(ni),
      no_(no),
      num_weights_(0),
      name_(name),
      forward_win_(nullptr),
      backward_win_(nullptr),
      randomizer_(nullptr) {}
// Suspends/Enables/Permanently disables training by updating the training_
// flag according to the requested state:
//  - TS_RE_ENABLE:    returns to TS_ENABLED, but only when the layer is
//                     currently TS_TEMP_DISABLE; otherwise nothing changes.
//  - TS_TEMP_DISABLE: takes effect only when the layer is currently
//                     TS_ENABLED, allowing a trainer to be serialized as if
//                     it were a recognizer.
//  - anything else:   stored unconditionally.
// NOTE(review): this base implementation only changes the flag. Any
// re-initialization of weight deltas needed to turn a recognizer back into a
// trainer is not performed here; presumably sub-class overrides handle it -
// confirm.
void Network::SetEnableTraining(TrainingState state) {
  switch (state) {
    case TS_RE_ENABLE:
      // Enable only from temp disabled.
      if (training_ == TS_TEMP_DISABLE) {
        training_ = TS_ENABLED;
      }
      break;
    case TS_TEMP_DISABLE:
      // Temp disable only from enabled.
      if (training_ == TS_ENABLED) {
        training_ = TS_TEMP_DISABLE;
      }
      break;
    default:
      training_ = state;
      break;
  }
}
// Sets flags that control the action of the network. See NetworkFlags enum
// for bit values.
void Network::SetNetworkFlags(uint32_t flags) {
network_flags_ = flags;
}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int Network::InitWeights([[maybe_unused]] float range, TRand *randomizer) {
randomizer_ = randomizer;
return 0;
}
// Gives the network a TRand to draw random numbers from (see Random()).
// The pointer is borrowed: the randomizer should outlive the network and
// should not be deleted by any network.
void Network::SetRandomizer(TRand *randomizer) {
  this->randomizer_ = randomizer;
}
// Records in needs_to_backprop_ whether this layer has to produce backprop
// output, and returns true if needs_backprop is set or this network has any
// weights, so that the next layer forward can be told to produce backprop
// for this layer if needed.
bool Network::SetupNeedsBackprop(bool needs_backprop) {
  needs_to_backprop_ = needs_backprop;
  if (needs_backprop) {
    return true;
  }
  return num_weights_ > 0;
}
// Writes the common network header to the given file. Returns false as soon
// as any write fails. Fields are written in this order:
//   1. an NT_NONE byte, which getNetworkType treats as "a type name follows",
//   2. the type name string taken from kTypeNames,
//   3. training_ and needs_to_backprop_, one byte each,
//   4. network_flags_, ni_, no_ and num_weights_,
//   5. the length of name_ followed by its characters.
bool Network::Serialize(TFile *fp) const {
  const int8_t type_name_marker = NT_NONE;
  std::string type_name = kTypeNames[type_];
  const int8_t training_byte = training_;
  const int8_t backprop_byte = needs_to_backprop_;
  const uint32_t name_length = name_.length();
  // && short-circuits, so the writes happen in order and stop at the first
  // failure, exactly like a sequence of early returns.
  return fp->Serialize(&type_name_marker) &&
         fp->Serialize(type_name) &&
         fp->Serialize(&training_byte) &&
         fp->Serialize(&backprop_byte) &&
         fp->Serialize(&network_flags_) &&
         fp->Serialize(&ni_) &&
         fp->Serialize(&no_) &&
         fp->Serialize(&num_weights_) &&
         fp->Serialize(&name_length) &&
         fp->Serialize(name_.c_str(), name_length);
}
// Reads the layer type from the given file.
// The first byte is either a raw type value (presumably an older file layout
// - confirm) or NT_NONE, in which case a type name string follows and is
// looked up in kTypeNames.
// Returns NT_NONE if a read fails, if the name is not a known type name, or
// if the raw type value is outside [0, NT_COUNT).
static NetworkType getNetworkType(TFile *fp) {
  int8_t data;
  if (!fp->DeSerialize(&data)) {
    return NT_NONE;
  }
  if (data == NT_NONE) {
    std::string type_name;
    if (!fp->DeSerialize(type_name)) {
      return NT_NONE;
    }
    // Use an int index rather than re-using the int8_t read buffer.
    for (int index = 0; index < NT_COUNT; ++index) {
      if (type_name == kTypeNames[index]) {
        return static_cast<NetworkType>(index);
      }
    }
    tprintf("Invalid network layer type:%s\n", type_name.c_str());
    return NT_NONE;
  }
  // The byte comes straight from the file: never convert a value that is not
  // a valid NetworkType into the enum.
  if (data < 0 || data >= NT_COUNT) {
    tprintf("Invalid network layer type:%d\n", static_cast<int>(data));
    return NT_NONE;
  }
  return static_cast<NetworkType>(data);
}
// Factory that reads a network from the given file. Returns nullptr in case
// of error.
// Reads the common header (in the order Network::Serialize writes it),
// constructs an object of the derived class matching the serialized type,
// copies the header state into it and then calls its DeSerialize to read the
// type-specific data. The caller receives a newly allocated network.
Network *Network::CreateFromFile(TFile *fp) {
  // Type of the derived network class.
  const NetworkType type = getNetworkType(fp);

  int8_t training_byte = 0;   // Serialized training state.
  int8_t backprop_byte = 0;   // Serialized needs_to_backprop flag.
  int32_t network_flags = 0;  // Behavior control flags in NetworkFlags.
  int32_t ni = 0;             // Number of input values.
  int32_t no = 0;             // Number of output values.
  int32_t num_weights = 0;    // Number of weights in this and sub-network.
  std::string name;           // A unique name for this layer.

  // && short-circuits, so fields are read in order and reading stops at the
  // first failure.
  const bool header_ok = fp->DeSerialize(&training_byte) &&
                         fp->DeSerialize(&backprop_byte) &&
                         fp->DeSerialize(&network_flags) &&
                         fp->DeSerialize(&ni) &&
                         fp->DeSerialize(&no) &&
                         fp->DeSerialize(&num_weights) &&
                         fp->DeSerialize(name);
  if (!header_ok) {
    return nullptr;
  }

  // Any serialized state other than TS_ENABLED is loaded as TS_DISABLED.
  const TrainingState training = training_byte == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
  // This network needs to output back_deltas.
  const bool needs_to_backprop = backprop_byte != 0;

  Network *network = nullptr;
  switch (type) {
    case NT_CONVOLVE:
      network = new Convolve(name.c_str(), ni, 0, 0);
      break;
    case NT_INPUT:
      network = new Input(name.c_str(), ni, no);
      break;
    // All variants of LSTM.
    case NT_LSTM:
    case NT_LSTM_SOFTMAX:
    case NT_LSTM_SOFTMAX_ENCODED:
    case NT_LSTM_SUMMARY:
      network = new LSTM(name.c_str(), ni, no, no, false, type);
      break;
    case NT_MAXPOOL:
      network = new Maxpool(name.c_str(), ni, 0, 0);
      break;
    // All variants of Parallel.
    case NT_PARALLEL:
    case NT_REPLICATED:
    case NT_PAR_RL_LSTM:
    case NT_PAR_UD_LSTM:
    case NT_PAR_2D_LSTM:
      network = new Parallel(name.c_str(), type);
      break;
    case NT_RECONFIG:
      network = new Reconfig(name.c_str(), ni, 0, 0);
      break;
    // All variants of reversed.
    case NT_XREVERSED:
    case NT_YREVERSED:
    case NT_XYTRANSPOSE:
      network = new Reversed(name.c_str(), type);
      break;
    case NT_SERIES:
      network = new Series(name.c_str());
      break;
    case NT_TENSORFLOW:
#ifdef INCLUDE_TENSORFLOW
      network = new TFNetwork(name.c_str());
#else
      tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
#endif
      break;
    // All variants of FullyConnected.
    case NT_SOFTMAX:
    case NT_SOFTMAX_NO_CTC:
    case NT_RELU:
    case NT_TANH:
    case NT_LINEAR:
    case NT_LOGISTIC:
    case NT_POSCLIP:
    case NT_SYMCLIP:
      network = new FullyConnected(name.c_str(), ni, no, type);
      break;
    default:
      // Unknown or invalid type: nothing is constructed.
      break;
  }
  if (network == nullptr) {
    return nullptr;
  }

  // Restore the header state before the derived class reads its own data.
  network->training_ = training;
  network->needs_to_backprop_ = needs_to_backprop;
  network->network_flags_ = network_flags;
  network->num_weights_ = num_weights;
  if (!network->DeSerialize(fp)) {
    delete network;
    return nullptr;
  }
  return network;
}
// Returns a random number in [-range, range], drawn from the randomizer
// previously supplied via SetRandomizer or InitWeights. A randomizer must
// have been set: this is enforced with ASSERT_HOST.
TFloat Network::Random(TFloat range) {
  ASSERT_HOST(randomizer_ != nullptr);
  const TFloat value = randomizer_->SignedRand(range);
  return value;
}
#ifndef GRAPHICS_DISABLED
// === Debug image display methods. ===
// Displays the image of the matrix in the forward window, which is titled
// with this network's name. The window is created on first use and cleared
// on later calls.
void Network::DisplayForward(const NetworkIO &matrix) {
  Image image = matrix.ToPix();
  const int image_width = pixGetWidth(image);
  const int image_height = pixGetHeight(image);
  ClearWindow(false, name_.c_str(), image_width, image_height, &forward_win_);
  // DisplayImage destroys the image once it has been drawn.
  DisplayImage(image, forward_win_);
  forward_win_->Update();
}
// Displays the image of the matrix in the backward window, which is titled
// with this network's name followed by "-back". The window is created on
// first use and cleared on later calls.
void Network::DisplayBackward(const NetworkIO &matrix) {
  Image image = matrix.ToPix();
  const std::string window_name = name_ + "-back";
  const int image_width = pixGetWidth(image);
  const int image_height = pixGetHeight(image);
  ClearWindow(false, window_name.c_str(), image_width, image_height, &backward_win_);
  // DisplayImage destroys the image once it has been drawn.
  DisplayImage(image, backward_win_);
  backward_win_->Update();
}
// Creates the window if needed, otherwise clears it.
// When *window is null, a new ScrollView named window_name is allocated and
// stored in *window. Its size is derived from width/height: if the smaller
// dimension is below kMinWinSize both are scaled up by the same factor, the
// frame allowance is added, and each dimension is capped at kMaxWinSize.
// tess_coords is passed through to the ScrollView constructor.
void Network::ClearWindow(bool tess_coords, const char *window_name, int width, int height,
                          ScrollView **window) {
  if (*window != nullptr) {
    // The window already exists: just wipe its contents.
    (*window)->Clear();
    return;
  }
  const int smaller_side = std::min(width, height);
  if (smaller_side < kMinWinSize) {
    // Guard against division by zero (or a negative divisor).
    const int divisor = smaller_side < 1 ? 1 : smaller_side;
    width = width * kMinWinSize / divisor;
    height = height * kMinWinSize / divisor;
  }
  width = std::min(width + kXWinFrameSize, kMaxWinSize);
  height = std::min(height + kYWinFrameSize, kMaxWinSize);
  *window = new ScrollView(window_name, 80, 100, width, height, width, height, tess_coords);
  tprintf("Created window %s of size %d, %d\n", window_name, width, height);
}
// Draws the pix at the origin (0, 0) of the given window and returns the
// height of the pix. The pix is destroyed before returning, so the caller
// must not use it afterwards.
int Network::DisplayImage(Image pix, ScrollView *window) {
  // Read the height first: it is not available once the pix is destroyed.
  const int pix_height = pixGetHeight(pix);
  window->Draw(pix, 0, 0);
  pix.destroy();
  return pix_height;
}
#endif // !GRAPHICS_DISABLED
} // namespace tesseract.