Skip to content

Commit

Permalink
Prepare using float instead of double for LSTM calculations
Browse files Browse the repository at this point in the history
The new header file ccutils/tesstypes.h also prepares support
for larger images by introducing a new data type for image
size and coordinates (still unused).

FloatToDouble is now a local function.

Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed Jul 24, 2021
1 parent c3fb050 commit 66b77e6
Show file tree
Hide file tree
Showing 27 changed files with 265 additions and 221 deletions.
6 changes: 6 additions & 0 deletions Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -150,38 +150,44 @@ endif
if MARCH_NATIVE_OPT
libtesseract_native_la_CXXFLAGS += -march=native -mtune=native
endif
libtesseract_native_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
libtesseract_native_la_SOURCES = src/arch/dotproduct.cpp

if HAVE_AVX
libtesseract_avx_la_CXXFLAGS = -mavx
libtesseract_avx_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
libtesseract_avx_la_SOURCES = src/arch/dotproductavx.cpp
libtesseract_la_LIBADD += libtesseract_avx.la
noinst_LTLIBRARIES += libtesseract_avx.la
endif

if HAVE_AVX2
libtesseract_avx2_la_CXXFLAGS = -mavx2
libtesseract_avx2_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
libtesseract_avx2_la_SOURCES = src/arch/intsimdmatrixavx2.cpp
libtesseract_la_LIBADD += libtesseract_avx2.la
noinst_LTLIBRARIES += libtesseract_avx2.la
endif

if HAVE_FMA
libtesseract_fma_la_CXXFLAGS = -mfma
libtesseract_fma_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
libtesseract_fma_la_SOURCES = src/arch/dotproductfma.cpp
libtesseract_la_LIBADD += libtesseract_fma.la
noinst_LTLIBRARIES += libtesseract_fma.la
endif

if HAVE_SSE4_1
libtesseract_sse_la_CXXFLAGS = -msse4.1
libtesseract_sse_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
libtesseract_sse_la_SOURCES = src/arch/dotproductsse.cpp src/arch/intsimdmatrixsse.cpp
libtesseract_la_LIBADD += libtesseract_sse.la
noinst_LTLIBRARIES += libtesseract_sse.la
endif

if HAVE_NEON
libtesseract_neon_la_CXXFLAGS = $(NEON_CXXFLAGS)
libtesseract_neon_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
libtesseract_neon_la_SOURCES = src/arch/intsimdmatrixneon.cpp
libtesseract_la_LIBADD += libtesseract_neon.la
noinst_LTLIBRARIES += libtesseract_neon.la
Expand Down
6 changes: 3 additions & 3 deletions src/arch/dotproduct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
namespace tesseract {

// Computes and returns the dot product of the two n-vectors u and v.
double DotProductNative(const double *u, const double *v, int n) {
double total = 0.0;
TFloat DotProductNative(const TFloat *u, const TFloat *v, int n) {
TFloat total = 0;
#if defined(OPENMP_SIMD) || defined(_OPENMP)
#pragma omp simd reduction(+:total)
#endif
for (int k = 0; k < n; ++k) {
for (int k = 0; k < n; k++) {
total += u[k] * v[k];
}
return total;
Expand Down
10 changes: 6 additions & 4 deletions src/arch/dotproduct.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
#ifndef TESSERACT_ARCH_DOTPRODUCT_H_
#define TESSERACT_ARCH_DOTPRODUCT_H_

#include "tesstypes.h"

namespace tesseract {

// Computes and returns the dot product of the n-vectors u and v.
double DotProductNative(const double *u, const double *v, int n);
TFloat DotProductNative(const TFloat *u, const TFloat *v, int n);

// Uses Intel AVX intrinsics to access the SIMD instruction set.
double DotProductAVX(const double *u, const double *v, int n);
TFloat DotProductAVX(const TFloat *u, const TFloat *v, int n);

// Use Intel FMA.
double DotProductFMA(const double *u, const double *v, int n);
TFloat DotProductFMA(const TFloat *u, const TFloat *v, int n);

// Uses Intel SSE intrinsics to access the SIMD instruction set.
double DotProductSSE(const double *u, const double *v, int n);
TFloat DotProductSSE(const TFloat *u, const TFloat *v, int n);

} // namespace tesseract.

Expand Down
2 changes: 1 addition & 1 deletion src/arch/intsimdmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t> &w, std::vector<int8_t>
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it.
void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t> &w,
const std::vector<double> &scales, const int8_t *u, double *v) {
const std::vector<TFloat> &scales, const int8_t *u, TFloat *v) {
int num_out = w.dim1();
int num_in = w.dim2() - 1;
// Base implementation.
Expand Down
10 changes: 6 additions & 4 deletions src/arch/intsimdmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <cstdint>
#include <vector>

#include "tesstypes.h"

namespace tesseract {

template <class T>
Expand Down Expand Up @@ -78,8 +80,8 @@ struct TESS_API IntSimdMatrix {
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it.
// Computes the base C++ implementation.
static void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t> &w, const std::vector<double> &scales,
const int8_t *u, double *v);
static void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t> &w, const std::vector<TFloat> &scales,
const int8_t *u, TFloat *v);

// Rounds the input up to a multiple of the given factor.
static int Roundup(int input, int factor) {
Expand All @@ -95,8 +97,8 @@ struct TESS_API IntSimdMatrix {
// RoundInputs above.
// The input will be over-read to the extent of the padding. There are no
// alignment requirements.
using MatrixDotVectorFunction = void (*)(int, int, const int8_t *, const double *, const int8_t *,
double *);
using MatrixDotVectorFunction = void (*)(int, int, const int8_t *, const TFloat *, const int8_t *,
TFloat *);
MatrixDotVectorFunction matrixDotVectorFunction;

// Number of 32 bit outputs held in each register.
Expand Down
12 changes: 7 additions & 5 deletions src/arch/intsimdmatrixneon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#if defined(__ARM_NEON)

# include "intsimdmatrix.h"
# include "tesstypes.h"

# include <algorithm>
# include <cstdint>
Expand Down Expand Up @@ -52,9 +53,9 @@ constexpr int kNumInputsPerGroup = 8;
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static inline void PartialMatrixDotVector8(const int8_t *__restrict wi,
const double *__restrict scales,
const TFloat *__restrict scales,
const int8_t *__restrict u, int num_in,
double *__restrict v, int num_out) {
TFloat *__restrict v, int num_out) {
// Initialize all the results to 0.
int32x4_t result0123 = {0, 0, 0, 0};
int32x4_t result4567 = {0, 0, 0, 0};
Expand Down Expand Up @@ -163,8 +164,8 @@ static inline void PartialMatrixDotVector8(const int8_t *__restrict wi,
}
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
const int8_t *u, double *v) {
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
const int8_t *u, TFloat *v) {
const int num_out = dim1;
const int num_in = dim2 - 1;
// Each call to a partial_func_ produces group_size outputs, except the
Expand Down Expand Up @@ -196,7 +197,8 @@ const IntSimdMatrix IntSimdMatrix::intSimdMatrixNEON = {
// Number of 8 bit inputs in the inputs register.
kNumInputsPerRegister,
// Number of inputs in each weight group.
kNumInputsPerGroup};
kNumInputsPerGroup
};

} // namespace tesseract.

Expand Down
13 changes: 7 additions & 6 deletions src/arch/intsimdmatrixsse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ static int32_t IntDotProductSSE(const int8_t *u, const int8_t *v, int n) {
}

// Computes part of matrix.vector v = Wu. Computes 1 result.
static void PartialMatrixDotVector1(const int8_t *wi, const double *scales, const int8_t *u,
int num_in, double *v) {
double total = IntDotProductSSE(u, wi, num_in);
static void PartialMatrixDotVector1(const int8_t *wi, const TFloat *scales, const int8_t *u,
int num_in, TFloat *v) {
TFloat total = IntDotProductSSE(u, wi, num_in);
// Add in the bias and correct for integer values.
*v = (total + wi[num_in] * INT8_MAX) * *scales;
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
const int8_t *u, double *v) {
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
const int8_t *u, TFloat *v) {
const int num_out = dim1;
const int num_in = dim2 - 1;
int output = 0;
Expand All @@ -99,7 +99,8 @@ const IntSimdMatrix IntSimdMatrix::intSimdMatrixSSE = {
// Number of 8 bit inputs in the inputs register.
1,
// Number of inputs in each weight group.
1};
1
};

} // namespace tesseract.

Expand Down
12 changes: 6 additions & 6 deletions src/arch/simddetect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,26 @@ bool SIMDDetect::sse_available_;
#endif

#if defined(HAVE_FRAMEWORK_ACCELERATE)
static double DotProductAccelerate(const double* u, const double* v, int n) {
double total = 0.0;
static TFloat DotProductAccelerate(const TFloat* u, const TFloat* v, int n) {
TFloat total = 0;
const int stride = 1;
vDSP_dotprD(u, stride, v, stride, &total, n);
return total;
}
#endif

// Computes and returns the dot product of the two n-vectors u and v.
static double DotProductGeneric(const double *u, const double *v, int n) {
double total = 0.0;
static TFloat DotProductGeneric(const TFloat *u, const TFloat *v, int n) {
TFloat total = 0;
for (int k = 0; k < n; ++k) {
total += u[k] * v[k];
}
return total;
}

// Compute dot product using std::inner_product.
static double DotProductStdInnerProduct(const double *u, const double *v, int n) {
return std::inner_product(u, u + n, v, 0.0);
static TFloat DotProductStdInnerProduct(const TFloat *u, const TFloat *v, int n) {
return std::inner_product(u, u + n, v, static_cast<TFloat>(0));
}

static void SetDotProduct(DotProductFunction f, const IntSimdMatrix *m = nullptr) {
Expand Down
3 changes: 2 additions & 1 deletion src/arch/simddetect.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
#define TESSERACT_ARCH_SIMDDETECT_H_

#include <tesseract/export.h>
#include "tesstypes.h"

namespace tesseract {

// Function pointer for best calculation of dot product.
using DotProductFunction = double (*)(const double *, const double *, int);
using DotProductFunction = TFloat (*)(const TFloat *, const TFloat *, int);
extern DotProductFunction DotProduct;

// Architecture detector. Add code here to detect any other architectures for
Expand Down
32 changes: 32 additions & 0 deletions src/ccutil/tesstypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
///////////////////////////////////////////////////////////////////////
// File: tesstypes.h
// Description: Simple data types used by Tesseract code.
// Author: Stefan Weil
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_TESSTYPES_H
#define TESSERACT_TESSTYPES_H

#include <cstdint> // for int16_t

namespace tesseract {

// Image dimensions (width and height, coordinates).
using TDimension = int16_t;

// Floating point data type used for LSTM calculations.
using TFloat = double;

}

#endif // TESSERACT_TESSTYPES_H
18 changes: 9 additions & 9 deletions src/lstm/fullyconnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ void FullyConnected::Forward(bool debug, const NetworkIO &input,
// Thread-local pointer to temporary storage.
int thread_id = 0;
#endif
double *temp_line = temp_lines[thread_id];
TFloat *temp_line = temp_lines[thread_id];
if (input.int_mode()) {
ForwardTimeStep(input.i(t), t, temp_line);
} else {
Expand Down Expand Up @@ -200,7 +200,7 @@ void FullyConnected::SetupForward(const NetworkIO &input, const TransposedArray
}
}

void FullyConnected::ForwardTimeStep(int t, double *output_line) {
void FullyConnected::ForwardTimeStep(int t, TFloat *output_line) {
if (type_ == NT_TANH) {
FuncInplace<GFunc>(no_, output_line);
} else if (type_ == NT_LOGISTIC) {
Expand All @@ -218,7 +218,7 @@ void FullyConnected::ForwardTimeStep(int t, double *output_line) {
}
}

void FullyConnected::ForwardTimeStep(const double *d_input, int t, double *output_line) {
void FullyConnected::ForwardTimeStep(const TFloat *d_input, int t, TFloat *output_line) {
// input is copied to source_ line-by-line for cache coherency.
if (IsTraining() && external_source_ == nullptr) {
source_t_.WriteStrided(t, d_input);
Expand All @@ -227,7 +227,7 @@ void FullyConnected::ForwardTimeStep(const double *d_input, int t, double *outpu
ForwardTimeStep(t, output_line);
}

void FullyConnected::ForwardTimeStep(const int8_t *i_input, int t, double *output_line) {
void FullyConnected::ForwardTimeStep(const int8_t *i_input, int t, TFloat *output_line) {
// input is copied to source_ line-by-line for cache coherency.
weights_.MatrixDotVector(i_input, output_line);
ForwardTimeStep(t, output_line);
Expand Down Expand Up @@ -265,11 +265,11 @@ bool FullyConnected::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkSc
for (int t = 0; t < width; ++t) {
int thread_id = 0;
#endif
double *backprop = nullptr;
TFloat *backprop = nullptr;
if (needs_to_backprop_) {
backprop = temp_backprops[thread_id];
}
double *curr_errors = errors[thread_id];
TFloat *curr_errors = errors[thread_id];
BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
if (backprop != nullptr) {
back_deltas->WriteTimeStep(t, backprop);
Expand All @@ -287,8 +287,8 @@ bool FullyConnected::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkSc
return false; // No point going further back.
}

void FullyConnected::BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors,
TransposedArray *errors_t, double *backprop) {
void FullyConnected::BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors,
TransposedArray *errors_t, TFloat *backprop) {
if (type_ == NT_TANH) {
acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
} else if (type_ == NT_LOGISTIC) {
Expand Down Expand Up @@ -328,7 +328,7 @@ void FullyConnected::Update(float learning_rate, float momentum, float adam_beta
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void FullyConnected::CountAlternators(const Network &other, double *same, double *changed) const {
void FullyConnected::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
ASSERT_HOST(other.type() == type_);
const auto *fc = static_cast<const FullyConnected *>(&other);
weights_.CountAlternators(fc->weights_, same, changed);
Expand Down
13 changes: 7 additions & 6 deletions src/lstm/fullyconnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "network.h"
#include "networkscratch.h"
#include "tesstypes.h"

namespace tesseract {

Expand Down Expand Up @@ -90,17 +91,17 @@ class FullyConnected : public Network {
NetworkScratch *scratch, NetworkIO *output) override;
// Components of Forward so FullyConnected can be reused inside LSTM.
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose);
void ForwardTimeStep(int t, double *output_line);
void ForwardTimeStep(const double *d_input, int t, double *output_line);
void ForwardTimeStep(const int8_t *i_input, int t, double *output_line);
void ForwardTimeStep(int t, TFloat *output_line);
void ForwardTimeStep(const TFloat *d_input, int t, TFloat *output_line);
void ForwardTimeStep(const int8_t *i_input, int t, TFloat *output_line);

// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
NetworkIO *back_deltas) override;
// Components of Backward so FullyConnected can be reused inside LSTM.
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors,
TransposedArray *errors_t, double *backprop);
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors,
TransposedArray *errors_t, TFloat *backprop);
void FinishBackward(const TransposedArray &errors_t);

// Updates the weights using the given learning rate, momentum and adam_beta.
Expand All @@ -109,7 +110,7 @@ class FullyConnected : public Network {
// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void CountAlternators(const Network &other, double *same, double *changed) const override;
void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override;

protected:
// Weight arrays of size [no, ni + 1].
Expand Down
Loading

0 comments on commit 66b77e6

Please sign in to comment.