Skip to content

Commit

Permalink
Replace STRING by std::string in src/lstm
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed Mar 13, 2021
1 parent 5989409 commit d6495d9
Show file tree
Hide file tree
Showing 13 changed files with 45 additions and 61 deletions.
7 changes: 2 additions & 5 deletions src/lstm/convolve.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,8 @@ class Convolve : public Network {
Convolve(const std::string &name, int ni, int half_x, int half_y);
~Convolve() override = default;

STRING spec() const override {
STRING spec;
spec.add_str_int("C", half_x_ * 2 + 1);
spec.add_str_int(",", half_y_ * 2 + 1);
return spec;
std::string spec() const override {
return "C" + std::to_string(half_x_ * 2 + 1) + "," + std::to_string(half_y_ * 2 + 1);
}

// Writes to the given file. Returns false in case of error.
Expand Down
20 changes: 10 additions & 10 deletions src/lstm/fullyconnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@ class FullyConnected : public Network {
// be partially unknown ie zero).
StaticShape OutputShape(const StaticShape &input_shape) const override;

STRING spec() const override {
STRING spec;
std::string spec() const override {
std::string spec;
if (type_ == NT_TANH)
spec.add_str_int("Ft", no_);
spec += "Ft" + std::to_string(no_);
else if (type_ == NT_LOGISTIC)
spec.add_str_int("Fs", no_);
spec += "Fs" + std::to_string(no_);
else if (type_ == NT_RELU)
spec.add_str_int("Fr", no_);
spec += "Fr" + std::to_string(no_);
else if (type_ == NT_LINEAR)
spec.add_str_int("Fl", no_);
spec += "Fl" + std::to_string(no_);
else if (type_ == NT_POSCLIP)
spec.add_str_int("Fp", no_);
spec += "Fp" + std::to_string(no_);
else if (type_ == NT_SYMCLIP)
spec.add_str_int("Fn", no_);
spec += "Fn" + std::to_string(no_);
else if (type_ == NT_SOFTMAX)
spec.add_str_int("Fc", no_);
spec += "Fc" + std::to_string(no_);
else
spec.add_str_int("Fm", no_);
spec += "Fm" + std::to_string(no_);
return spec;
}

Expand Down
10 changes: 3 additions & 7 deletions src/lstm/input.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ class Input : public Network {
Input(const std::string &name, const StaticShape &shape);
~Input() override = default;

STRING spec() const override {
STRING spec;
spec.add_str_int("", shape_.batch());
spec.add_str_int(",", shape_.height());
spec.add_str_int(",", shape_.width());
spec.add_str_int(",", shape_.depth());
return spec;
std::string spec() const override {
return std::to_string(shape_.batch()) + "," + std::to_string(shape_.height()) + "," +
std::to_string(shape_.width()) + "," + std::to_string(shape_.depth());
}

// Returns the required shape input to the network.
Expand Down
12 changes: 6 additions & 6 deletions src/lstm/lstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ class LSTM : public Network {
// be partially unknown ie zero).
StaticShape OutputShape(const StaticShape &input_shape) const override;

STRING spec() const override {
STRING spec;
std::string spec() const override {
std::string spec;
if (type_ == NT_LSTM)
spec.add_str_int("Lfx", ns_);
spec += "Lfx" + std::to_string(ns_);
else if (type_ == NT_LSTM_SUMMARY)
spec.add_str_int("Lfxs", ns_);
spec += "Lfxs" + std::to_string(ns_);
else if (type_ == NT_LSTM_SOFTMAX)
spec.add_str_int("LS", ns_);
spec += "LS" + std::to_string(ns_);
else if (type_ == NT_LSTM_SOFTMAX_ENCODED)
spec.add_str_int("LE", ns_);
spec += "LE" + std::to_string(ns_);
if (softmax_ != nullptr)
spec += softmax_->spec();
return spec;
Expand Down
7 changes: 2 additions & 5 deletions src/lstm/maxpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,8 @@ class Maxpool : public Reconfig {
~Maxpool() override = default;

// Accessors.
STRING spec() const override {
STRING spec;
spec.add_str_int("Mp", y_scale_);
spec.add_str_int(",", x_scale_);
return spec;
std::string spec() const override {
return "Mp" + std::to_string(y_scale_) + "," + std::to_string(x_scale_);
}

// Reads from the given file. Returns false in case of error.
Expand Down
3 changes: 1 addition & 2 deletions src/lstm/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "networkio.h"
#include "serialis.h"
#include "static_shape.h"
#include "strngs.h" // for STRING
#include "tprintf.h"

#include <cmath>
Expand Down Expand Up @@ -141,7 +140,7 @@ class TESS_API Network {
const std::string &name() const {
return name_;
}
virtual STRING spec() const {
virtual std::string spec() const {
return "?";
}
bool TestFlag(NetworkFlags flag) const {
Expand Down
15 changes: 6 additions & 9 deletions src/lstm/parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,23 @@ class Parallel : public Plumbing {
// be partially unknown ie zero).
StaticShape OutputShape(const StaticShape &input_shape) const override;

STRING spec() const override {
STRING spec;
std::string spec() const override {
std::string spec;
if (type_ == NT_PAR_2D_LSTM) {
// We have 4 LSTMs operating in parallel here, so the size of each is
// the number of outputs/4.
spec.add_str_int("L2xy", no_ / 4);
spec += "L2xy" + std::to_string(no_ / 4);
} else if (type_ == NT_PAR_RL_LSTM) {
// We have 2 LSTMs operating in parallel here, so the size of each is
// the number of outputs/2.
if (stack_[0]->type() == NT_LSTM_SUMMARY)
spec.add_str_int("Lbxs", no_ / 2);
spec += "Lbxs" + std::to_string(no_ / 2);
else
spec.add_str_int("Lbx", no_ / 2);
spec += "Lbx" + std::to_string(no_ / 2);
} else {
if (type_ == NT_REPLICATED) {
spec.add_str_int("R", stack_.size());
spec += "(";
spec += stack_[0]->spec();
spec += "R" + std::to_string(stack_.size()) + "(" + stack_[0]->spec();
} else {
spec = "(";
for (int i = 0; i < stack_.size(); ++i)
spec += stack_[i]->spec();
}
Expand Down
3 changes: 2 additions & 1 deletion src/lstm/plumbing.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "genericvector.h"
#include "matrix.h"
#include "network.h"
#include "strngs.h"

namespace tesseract {

Expand All @@ -36,7 +37,7 @@ class Plumbing : public Network {
StaticShape InputShape() const override {
return stack_[0]->InputShape();
}
STRING spec() const override {
std::string spec() const override {
return "Sub-classes of Plumbing must implement spec()!";
}

Expand Down
7 changes: 2 additions & 5 deletions src/lstm/reconfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,8 @@ class Reconfig : public Network {
// be partially unknown ie zero).
StaticShape OutputShape(const StaticShape &input_shape) const override;

STRING spec() const override {
STRING spec;
spec.add_str_int("S", y_scale_);
spec.add_str_int(",", x_scale_);
return spec;
std::string spec() const override {
return "S" + std::to_string(y_scale_) + "," + std::to_string(x_scale_);
}

// Returns an integer reduction factor that the network applies to the
Expand Down
10 changes: 5 additions & 5 deletions src/lstm/reversed.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// File: reversed.h
// Description: Runs a single network on time-reversed input, reversing output.
// Author: Ray Smith
// Created: Thu May 02 08:38:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -35,15 +34,15 @@ class Reversed : public Plumbing {
// be partially unknown ie zero).
StaticShape OutputShape(const StaticShape &input_shape) const override;

STRING spec() const override {
STRING spec(type_ == NT_XREVERSED ? "Rx" : (type_ == NT_YREVERSED ? "Ry" : "Txy"));
std::string spec() const override {
std::string spec(type_ == NT_XREVERSED ? "Rx" : (type_ == NT_YREVERSED ? "Ry" : "Txy"));
// For most simple cases, we will output Rx<net> or Ry<net> where <net> is
// the network in stack_[0], but in the special case that <net> is an
// LSTM, we will just output the LSTM's spec modified to take the reversal
// into account. This is because when the user specified Lfy64, we actually
// generated TxyLfx64, and if the user specified Lrx64 we actually
// generated RxLfx64, and we want to display what the user asked for.
STRING net_spec = stack_[0]->spec();
std::string net_spec(stack_[0]->spec());
if (net_spec[0] == 'L') {
// Setup a from and to character according to the type of the reversal
// such that the LSTM spec gets modified to the spec that the user
Expand All @@ -59,7 +58,8 @@ class Reversed : public Plumbing {
if (net_spec[i] == from)
net_spec[i] = to;
}
return net_spec;
spec += net_spec;
return spec;
}
spec += net_spec;
return spec;
Expand Down
5 changes: 2 additions & 3 deletions src/lstm/series.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// File: series.h
// Description: Runs networks in series on the same input.
// Author: Ray Smith
// Created: Thu May 02 08:20:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -35,8 +34,8 @@ class Series : public Plumbing {
// be partially unknown ie zero).
StaticShape OutputShape(const StaticShape &input_shape) const override;

STRING spec() const override {
STRING spec("[");
std::string spec() const override {
std::string spec("[");
for (int i = 0; i < stack_.size(); ++i)
spec += stack_[i]->spec();
spec += "]";
Expand Down
4 changes: 2 additions & 2 deletions src/lstm/tfnetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class TFNetwork : public Network {
return output_shape_;
}

STRING spec() const override {
return spec_.c_str();
std::string spec() const override {
return spec_;
}

// Deserializes *this from a serialized TFNetwork proto. Returns 0 if failed,
Expand Down
3 changes: 2 additions & 1 deletion src/training/unicharset/lstmtrainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ bool LSTMTrainer::InitNetwork(const char *network_spec, int append_index, int ne
return false;
}
network_str_ += network_spec;
tprintf("Built network:%s from request %s\n", network_->spec().c_str(), network_spec);
tprintf("Built network:%s from request %s\n",
network_->spec().c_str(), network_spec);
tprintf(
"Training parameters:\n Debug interval = %d,"
" weights = %g, learning rate = %g, momentum=%g\n",
Expand Down

0 comments on commit d6495d9

Please sign in to comment.