Skip to content

Commit

Permalink
Adds more unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Aug 2, 2022
1 parent 6e6add5 commit 69b88da
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 40 deletions.
67 changes: 39 additions & 28 deletions src/sentencepiece_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,34 @@ std::vector<absl::string_view> ToPieceArray(const std::vector<std::string> &v) {
return out;
}

void ConvertToUnicodeSpansInternal(SentencePieceText *spt) {
if (spt == nullptr) return;

std::vector<int> utf8_to_unicode(spt->text().size() + 1, 0);
absl::string_view str = spt->text();
size_t prev = 0;
int ulen = 0;
while (!str.empty()) {
const size_t mblen = string_util::OneCharLen(str.data());
for (int i = prev; i < prev + mblen; ++i) {
utf8_to_unicode[i] = ulen;
}
++ulen;
prev += mblen;
str.remove_prefix(mblen);
}
utf8_to_unicode[prev] = ulen;

auto clip = [&](int s) {
return std::min<int>(std::max<int>(0, s), utf8_to_unicode.size() - 1);
};

for (auto &piece : *(spt->mutable_pieces())) {
piece.set_begin(utf8_to_unicode[clip(piece.begin())]);
piece.set_end(utf8_to_unicode[clip(piece.end())]);
}
}

} // namespace

ImmutableSentencePieceText::ImmutableSentencePieceText()
Expand Down Expand Up @@ -132,6 +160,10 @@ SentencePieceText *ImmutableSentencePieceText::mutable_proto() {
return rep_.get();
}

void ImmutableSentencePieceText::ConvertToUnicodeSpans() {
ConvertToUnicodeSpansInternal(mutable_proto());
}

util::bytes ImmutableSentencePieceText::SerializeAsString() const {
return spt_->SerializeAsString();
}
Expand Down Expand Up @@ -164,6 +196,13 @@ NBestSentencePieceText *ImmutableNBestSentencePieceText::mutable_proto() {
return rep_.get();
}

void ImmutableNBestSentencePieceText::ConvertToUnicodeSpans() {
if (!mutable_proto()) return;
for (auto &spt : *(mutable_proto()->mutable_nbests())) {
ConvertToUnicodeSpansInternal(&spt);
}
}

util::bytes ImmutableNBestSentencePieceText::SerializeAsString() const {
return rep_ ? rep_->SerializeAsString() : "";
}
Expand Down Expand Up @@ -1048,34 +1087,6 @@ std::string SentencePieceProcessor::serialized_model_proto() const {
// std::random_device.
void SetRandomGeneratorSeed(unsigned int seed);

void ConvertToUnicodeSpans(SentencePieceText *spt) {
if (spt == nullptr) return;

std::vector<int> utf8_to_unicode(spt->text().size() + 1, 0);
absl::string_view str = spt->text();
size_t prev = 0;
int ulen = 0;
while (!str.empty()) {
const size_t mblen = string_util::OneCharLen(str.data());
for (int i = prev; i < prev + mblen; ++i) {
utf8_to_unicode[i] = ulen;
}
++ulen;
prev += mblen;
str.remove_prefix(mblen);
}
utf8_to_unicode[prev] = ulen;

auto clip = [&](int s) {
return std::min<int>(std::max<int>(0, s), utf8_to_unicode.size() - 1);
};

for (auto &piece : *(spt->mutable_pieces())) {
piece.set_begin(utf8_to_unicode[clip(piece.begin())]);
piece.set_end(utf8_to_unicode[clip(piece.end())]);
}
}

namespace io {
util::Status LoadModelProto(absl::string_view filename,
ModelProto *model_proto) {
Expand Down
19 changes: 12 additions & 7 deletions src/sentencepiece_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
#ifndef SWIG
namespace absl {
using std::string_view;
}
#endif // SWIG
} // namespace absl
#endif

namespace sentencepiece {
namespace util {
Expand Down Expand Up @@ -196,6 +196,9 @@ class ImmutableSentencePieceText {
// it returns the raw pointer managed by the shared_ptr.
SentencePieceText *mutable_proto();

// Converts the utf8 byte spans into Unicode char span.
void ConvertToUnicodeSpans();

friend class ImmutableNBestSentencePieceText;

private:
Expand Down Expand Up @@ -225,6 +228,8 @@ class ImmutableNBestSentencePieceText {
// it returns the raw pointer managed by the shared_ptr.
NBestSentencePieceText *mutable_proto();

void ConvertToUnicodeSpans();

private:
std::shared_ptr<NBestSentencePieceText> rep_;
};
Expand Down Expand Up @@ -415,14 +420,16 @@ class SentencePieceProcessor {
virtual util::Status Decode(const std::vector<int> &ids,
SentencePieceText *spt) const;

#ifdef SWIG
#ifdef SWIGPYTHON
#define CONVERT_TO_UNICODE_SPAN output.ConvertToUnicodeSpans();
#define SPP_SWIG_CHECK_AND_THROW \
if (!status.ok()) throw status;
#else
#define CONVERT_TO_UNICODE_SPAN
#define SPP_SWIG_CHECK_AND_THROW \
if (!status.ok()) { \
}
#endif // SWIG
#endif // SWIGPYTHON

#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
OutType output; \
Expand All @@ -439,6 +446,7 @@ class SentencePieceProcessor {
#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \
OutType output; \
const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
CONVERT_TO_UNICODE_SPAN; \
SPP_SWIG_CHECK_AND_THROW; \
return output;

Expand Down Expand Up @@ -707,9 +715,6 @@ class SentencePieceProcessor {
// std::random_device.
void SetRandomGeneratorSeed(unsigned int seed);

// Converts the utf8 byte spans into Unicode char span.
void ConvertToUnicodeSpans(SentencePieceText *spt);

#ifndef SWIG
// IO related functions to absorb model formats.
namespace io {
Expand Down
11 changes: 6 additions & 5 deletions src/sentencepiece_processor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1657,21 +1657,22 @@ TEST(SentencePieceProcessorTest, ImmutableNBestSentencePieceTextTest) {

TEST(SentencePieceProcessorTest, ConvertToUnicodeSpansTest) {
auto make_spt = [&](const std::vector<std::string> &tokens) {
SentencePieceText spt;
ImmutableSentencePieceText ispt;
auto *spt = ispt.mutable_proto();
int prev = 0;
std::string text;
for (const auto &tok : tokens) {
auto *piece = spt.add_pieces();
auto *piece = spt->add_pieces();
piece->set_surface(tok);
piece->set_piece(tok);
piece->set_begin(prev);
piece->set_end(prev + tok.size());
prev += tok.size();
text += tok;
}
spt.set_text(text);
ConvertToUnicodeSpans(&spt);
return spt;
spt->set_text(text);
ispt.ConvertToUnicodeSpans();
return ispt;
};

{
Expand Down

0 comments on commit 69b88da

Please sign in to comment.