Skip to content

Commit

Permalink
Declare tokenization options when setting a subword vocabulary (#186)
Browse files Browse the repository at this point in the history
This is required to make BPE vocabulary restriction work with spacer
annotation, joiner_new, etc.

This PR also makes the subword encoder `const` when assigned to a
tokenizer instance.
  • Loading branch information
guillaumekln authored Oct 26, 2020
1 parent caf67f1 commit 35aaea7
Show file tree
Hide file tree
Showing 12 changed files with 119 additions and 68 deletions.
10 changes: 5 additions & 5 deletions bindings/python/Python.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,6 @@ class TokenizerWrapper
vocabulary_threshold = bpe_vocab_threshold;
}

if (subword_encoder && !vocabulary_path.empty())
subword_encoder->load_vocabulary(vocabulary_path, vocabulary_threshold);

onmt::Tokenizer::Options options;
options.mode = onmt::Tokenizer::str_to_mode(mode);
options.no_substitution = no_substitution;
Expand All @@ -157,8 +154,11 @@ class TokenizerWrapper
if (!segment_alphabet.is(py::none()))
options.segment_alphabet = to_std_vector<std::string>(segment_alphabet.cast<py::list>());

if (subword_encoder && !vocabulary_path.empty())
subword_encoder->load_vocabulary(vocabulary_path, vocabulary_threshold, &options);

_tokenizer.reset(new onmt::Tokenizer(options,
std::shared_ptr<onmt::SubwordEncoder>(subword_encoder)));
std::shared_ptr<const onmt::SubwordEncoder>(subword_encoder)));
}

py::object tokenize(const std::string& text, const bool as_token_objects) const
Expand Down Expand Up @@ -344,7 +344,7 @@ class SubwordLearnerWrapper

auto* new_subword_encoder = create_subword_encoder(model_path);
auto* new_tokenizer = new onmt::Tokenizer(*_tokenizer);
new_tokenizer->set_subword_encoder(std::shared_ptr<onmt::SubwordEncoder>(new_subword_encoder));
new_tokenizer->set_subword_encoder(std::shared_ptr<const onmt::SubwordEncoder>(new_subword_encoder));
return TokenizerWrapper(new_tokenizer);
}

Expand Down
5 changes: 3 additions & 2 deletions cli/tokenize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ int main(int argc, char* argv[])
vm["sp_alpha"].as<float>());
}

auto options = build_tokenization_options(vm);
if (subword_encoder && !vocabulary.empty())
subword_encoder->load_vocabulary(vocabulary, vocabulary_threshold);
subword_encoder->load_vocabulary(vocabulary, vocabulary_threshold, &options);

onmt::Tokenizer tokenizer(build_tokenization_options(vm),
onmt::Tokenizer tokenizer(std::move(options),
std::shared_ptr<onmt::SubwordEncoder>(subword_encoder));

tokenizer.tokenize_stream(std::cin, std::cout, vm["num_threads"].as<int>());
Expand Down
5 changes: 0 additions & 5 deletions docs/options.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,6 @@ The following line formats are accepted:
* `<token><tab><frequency>`
* `<token>` (the token frequency is set to 1)

This feature currently requires subword encoders to be used with their "natural" tokenization settings, that is:

* SentencePiece: `--mode none --spacer_annotate`
* BPE: `--joiner_annotate`

### `vocabulary_threshold` (int, default: `0`)

When using `vocabulary_path`, any words with a frequency lower than `vocabulary_threshold` will be treated as OOV.
Expand Down
10 changes: 6 additions & 4 deletions include/onmt/BPE.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ namespace onmt
std::vector<std::string> encode(const std::string& str) const override;
std::vector<Token> encode_and_annotate(const Token& token) const override;

void set_vocabulary(const std::vector<std::string>& vocabulary) override;
void set_vocabulary(const std::vector<std::string>& vocabulary,
const Tokenizer::Options* options = nullptr) override;
void reset_vocabulary() override;

void set_joiner(const std::string& joiner)
{
_joiner = joiner;
_tokenization_options.joiner = joiner;
}

void set_dropout(const float dropout)
Expand All @@ -40,10 +41,11 @@ namespace onmt
bool _suffix;
bool _case_insensitive;
std::pair<int, int> _version;

std::string _joiner;
float _dropout;

// Tokenization options used to produce the vocabulary passed to set_vocabulary.
Tokenizer::Options _tokenization_options;

std::unordered_map<std::string, int> _codes;
std::unordered_map<std::string, std::pair<std::string, std::string> > _codes_reverse;
std::unordered_set<std::string> _bpe_vocab;
Expand Down
4 changes: 3 additions & 1 deletion include/onmt/SentencePiece.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ namespace onmt
SentencePiece(const std::string& model_path, int nbest_size, float alpha);
~SentencePiece();

void set_vocabulary(const std::vector<std::string>& vocabulary) override;
void update_tokenization_options(Tokenizer::Options& options) const override;
void set_vocabulary(const std::vector<std::string>& vocabulary,
const Tokenizer::Options* options = nullptr) override;
void reset_vocabulary() override;
void enable_regularization(int nbest_size, float alpha);

Expand Down
12 changes: 9 additions & 3 deletions include/onmt/SubwordEncoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <vector>

#include "onmt/opennmttokenizer_export.h"
#include "onmt/Token.h"
#include "onmt/Tokenizer.h"

namespace onmt
{
Expand All @@ -14,8 +14,14 @@ namespace onmt
public:
virtual ~SubwordEncoder() = default;

virtual void load_vocabulary(const std::string& path, int frequency_threshold);
virtual void set_vocabulary(const std::vector<std::string>& vocabulary);
// Maybe update the tokenization options for this subword encoder.
virtual void update_tokenization_options(Tokenizer::Options& options) const;

virtual void load_vocabulary(const std::string& path,
int frequency_threshold,
const Tokenizer::Options* tokenization_options = nullptr);
virtual void set_vocabulary(const std::vector<std::string>& vocabulary,
const Tokenizer::Options* tokenization_options = nullptr);
virtual void reset_vocabulary();

virtual std::vector<std::string> encode(const std::string& str) const = 0;
Expand Down
12 changes: 7 additions & 5 deletions include/onmt/Tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

#include "onmt/opennmttokenizer_export.h"
#include "onmt/ITokenizer.h"
#include "onmt/SubwordEncoder.h"
#include "onmt/Token.h"

namespace onmt
{

void set_random_seed(const unsigned int seed);

class SubwordEncoder;

// This Tokenizer implements the behaviour of OpenNMT's tools/tokenize.lua.
class OPENNMTTOKENIZER_EXPORT Tokenizer: public ITokenizer
{
Expand Down Expand Up @@ -61,7 +63,7 @@ namespace onmt
static const std::string ph_marker_close;

Tokenizer(Options options,
const std::shared_ptr<SubwordEncoder>& subword_encoder = nullptr);
const std::shared_ptr<const SubwordEncoder>& subword_encoder = nullptr);

using ITokenizer::tokenize;
using ITokenizer::detokenize;
Expand Down Expand Up @@ -94,7 +96,7 @@ namespace onmt
const std::vector<std::vector<std::string> >& features,
Ranges& ranges, bool merge_ranges = false) const override;

void set_subword_encoder(const std::shared_ptr<SubwordEncoder>& subword_encoder);
void set_subword_encoder(const std::shared_ptr<const SubwordEncoder>& subword_encoder);

const Options& options() const
{
Expand All @@ -106,7 +108,7 @@ namespace onmt
static const int number_alphabet = -3;

Options _options;
std::shared_ptr<SubwordEncoder> _subword_encoder;
std::shared_ptr<const SubwordEncoder> _subword_encoder;

void tokenize_on_placeholders(const std::string& text,
std::vector<Token>& annotated_tokens) const;
Expand Down Expand Up @@ -171,7 +173,7 @@ namespace onmt
// External subword encoder constructor.
// Note: the tokenizer takes ownership of the subword_encoder pointer.
Tokenizer(Mode mode,
SubwordEncoder* subword_encoder,
const SubwordEncoder* subword_encoder,
int flags = Flags::None,
const std::string& joiner = joiner_marker);

Expand Down
40 changes: 31 additions & 9 deletions src/BPE.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ namespace onmt
, _suffix(true)
, _case_insensitive(false)
, _version(0, 0)
, _joiner(Tokenizer::joiner_marker)
, _dropout(check_dropout(dropout))
{
load_model(model_path);

// For backward compatibility, assume the tokenization uses joiner annotation.
_tokenization_options.joiner_annotate = true;
_tokenization_options.joiner = Tokenizer::joiner_marker;
}

BPE::BPE(const std::string& model_path, const std::string& joiner, const float dropout)
Expand All @@ -47,10 +50,13 @@ namespace onmt
, _suffix(true)
, _case_insensitive(false)
, _version(0, 0)
, _joiner(joiner)
, _dropout(check_dropout(dropout))
{
load_model(model_path);

// For backward compatibility, assume the tokenization uses joiner annotation.
_tokenization_options.joiner_annotate = true;
_tokenization_options.joiner = joiner;
}

void BPE::load_model(const std::string& model_path)
Expand Down Expand Up @@ -297,10 +303,13 @@ namespace onmt
}
}

// Replaces the restricted subword vocabulary with the given token list and
// records the tokenization options that produced that vocabulary, so that
// in_vocabulary() can rebuild token surfaces with the matching joiner or
// spacer annotation.
void BPE::set_vocabulary(const std::vector<std::string>& vocabulary,
                         const Tokenizer::Options* options)
{
  _bpe_vocab.clear();
  _bpe_vocab.insert(vocabulary.begin(), vocabulary.end());
  if (options)  // A null options pointer keeps the previously declared options.
    _tokenization_options = *options;
}

void BPE::reset_vocabulary()
Expand All @@ -315,12 +324,25 @@ namespace onmt

// Returns whether a token belongs to the restricted vocabulary.
//
// The lookup surface is rebuilt exactly as it would appear in the vocabulary
// file, using the tokenization options declared via set_vocabulary(): with
// joiner annotation the joiner is attached to the surface, and with spacer
// annotation a spacer marker is prepended to tokens not attached to the
// previous token. (The stale pre-refactor early return that consulted only
// `_joiner` — and made this logic unreachable — has been removed.)
bool BPE::in_vocabulary(const onmt::Token& token) const
{
  std::string surface = token.surface;

  if (!token.preserve)
  {
    if (_tokenization_options.joiner_annotate && !_tokenization_options.joiner_new)
    {
      // Joiners are attached to the token surface itself.
      if (token.join_left)
        surface = _tokenization_options.joiner + surface;
      if (token.join_right)
        surface = surface + _tokenization_options.joiner;
    }
    else if (_tokenization_options.spacer_annotate && !_tokenization_options.spacer_new)
    {
      // A spacer marks a token that is NOT attached to the previous one.
      if (!token.join_left)
        surface = Tokenizer::spacer_marker + surface;
    }
  }

  return in_vocabulary(surface);
}

std::vector<Token> BPE::check_vocab_and_split(std::vector<Token> pieces) const
Expand Down
22 changes: 21 additions & 1 deletion src/SentencePiece.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,28 @@ namespace onmt
delete _processor;
}

void SentencePiece::set_vocabulary(const std::vector<std::string>& vocabulary)
// Declares the tokenization options required by this SentencePiece encoder.
// Called by the tokenizer when the encoder is attached.
void SentencePiece::update_tokenization_options(Tokenizer::Options& options) const
{
  // Maybe enable SentencePiece compatibility mode.
  // Only triggers when the user left the relevant options at their defaults
  // ("none" mode, no joiner/spacer annotation), so explicit settings win.
  if (options.mode == Tokenizer::Mode::None
      && !options.joiner_annotate
      && !options.spacer_annotate)
  {
    options.spacer_annotate = true;
    options.no_substitution = true;
  }
}

void SentencePiece::set_vocabulary(const std::vector<std::string>& vocabulary,
const Tokenizer::Options* options)
{
if (options
&& (options->mode != Tokenizer::Mode::None
|| !options->spacer_annotate
|| options->spacer_new))
throw std::invalid_argument("SentencePiece vocabulary restriction requires the tokenization "
"to use the \"none\" mode and \"spacer_annotate\" "
"(same as spm_encode)");
auto status = _processor->SetVocabulary(vocabulary);
if (!status.ok())
throw std::invalid_argument(status.ToString());
Expand Down
14 changes: 9 additions & 5 deletions src/SubwordEncoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
namespace onmt
{

void SubwordEncoder::load_vocabulary(const std::string& path, int frequency_threshold)
// Default implementation: a generic subword encoder declares no tokenization
// option requirements. Subclasses (e.g. SentencePiece) may override this to
// adjust the options when they are attached to a tokenizer.
void SubwordEncoder::update_tokenization_options(Tokenizer::Options&) const
{
}

void SubwordEncoder::load_vocabulary(const std::string& path,
int frequency_threshold,
const Tokenizer::Options* options)
{
std::ifstream in(path);
if (!in)
Expand Down Expand Up @@ -39,17 +45,15 @@ namespace onmt
vocab.emplace_back(std::move(token));
}

set_vocabulary(vocab);
set_vocabulary(vocab, options);
}

void SubwordEncoder::set_vocabulary(const std::vector<std::string>&)
// Default implementation: the base subword encoder does not support
// vocabulary restriction, so both arguments are ignored. Subclasses
// override this to apply the vocabulary. (The redundant bare `return;`
// in this void function has been dropped.)
void SubwordEncoder::set_vocabulary(const std::vector<std::string>&, const Tokenizer::Options*)
{
}

// Default implementation: nothing to reset since the base encoder stores no
// vocabulary. (The redundant bare `return;` in this void function has been
// dropped.)
void SubwordEncoder::reset_vocabulary()
{
}

std::vector<Token> SubwordEncoder::encode_and_annotate(const std::vector<Token>& tokens) const
Expand Down
41 changes: 13 additions & 28 deletions src/Tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ namespace onmt
}
}

Tokenizer::Tokenizer(Options options, const std::shared_ptr<SubwordEncoder>& subword_encoder)
Tokenizer::Tokenizer(Options options,
const std::shared_ptr<const SubwordEncoder>& subword_encoder)
: _options(std::move(options))
{
_options.validate();
Expand All @@ -124,24 +125,27 @@ namespace onmt
_options.validate();
if (!model_path.empty())
{
SubwordEncoder* subword_encoder = nullptr;
if (flags & Flags::SentencePieceModel)
set_subword_encoder(std::make_shared<SentencePiece>(model_path));
subword_encoder = new SentencePiece(model_path);
else
set_subword_encoder(std::make_shared<BPE>(model_path));
subword_encoder = new BPE(model_path);

if (!vocab_path.empty())
_subword_encoder->load_vocabulary(vocab_path, vocab_threshold);
subword_encoder->load_vocabulary(vocab_path, vocab_threshold, &_options);

set_subword_encoder(std::shared_ptr<const SubwordEncoder>(subword_encoder));
}
}

Tokenizer::Tokenizer(Mode mode,
SubwordEncoder* subword_encoder,
const SubwordEncoder* subword_encoder,
int flags,
const std::string& joiner)
: _options(mode, flags, joiner)
{
_options.validate();
set_subword_encoder(std::shared_ptr<SubwordEncoder>(subword_encoder));
set_subword_encoder(std::shared_ptr<const SubwordEncoder>(subword_encoder));
}

Tokenizer::Tokenizer(const std::string& sp_model_path,
Expand Down Expand Up @@ -917,30 +921,11 @@ namespace onmt
_options.joiner_annotate = _options.spacer_annotate = false;
}

void Tokenizer::set_subword_encoder(const std::shared_ptr<SubwordEncoder>& subword_encoder)
// Attaches a subword encoder to this tokenizer and lets the encoder declare
// any tokenization options it requires (e.g. SentencePiece enables its
// spacer-based compatibility mode via update_tokenization_options).
//
// The stale pre-refactor dispatch — dynamic_cast to SentencePiece/BPE with
// per-type option tweaks — has been removed: it duplicated the new virtual
// call and would no longer compile, since dynamic_cast<SentencePiece*> from
// a `const SubwordEncoder*` discards constness.
void Tokenizer::set_subword_encoder(const std::shared_ptr<const SubwordEncoder>& subword_encoder)
{
  _subword_encoder = subword_encoder;
  if (_subword_encoder)
    _subword_encoder->update_tokenization_options(_options);
}

bool Tokenizer::add_alphabet_to_segment(const std::string& alphabet)
Expand Down
Loading

0 comments on commit 35aaea7

Please sign in to comment.