Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup Token structure #153

Merged
merged 10 commits into from
Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add TokenType property
  • Loading branch information
guillaumekln committed Aug 3, 2020
commit 800ffdfc98af0b42dd4bc1f64f09e56308541fb6
8 changes: 7 additions & 1 deletion bindings/python/Python.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,17 +447,23 @@ PYBIND11_MODULE(pyonmttok, m)
.value("NONE", onmt::CaseModifier::Type::None)
.export_values();

// Expose onmt::TokenType to Python as pyonmttok.TokenType.
// export_values() additionally copies WORD / LEADING_SUBWORD /
// TRAILING_SUBWORD into the enclosing module scope, so both
// pyonmttok.TokenType.WORD and pyonmttok.WORD resolve.
py::enum_<onmt::TokenType>(m, "TokenType")
.value("WORD", onmt::TokenType::Word)
.value("LEADING_SUBWORD", onmt::TokenType::LeadingSubword)
.value("TRAILING_SUBWORD", onmt::TokenType::TrailingSubword)
.export_values();

// Python binding of the onmt::Token structure. Each C++ field is
// exposed as a mutable attribute; the Python attribute "casing"
// maps to the C++ member case_type. __eq__ delegates to
// Token::operator== for value comparison.
py::class_<onmt::Token>(m, "Token")
.def(py::init<>())  // empty token
.def(py::init<std::string>())  // token initialized with a surface string
.def_readwrite("surface", &onmt::Token::surface)
.def_readwrite("type", &onmt::Token::type)
.def_readwrite("join_left", &onmt::Token::join_left)
.def_readwrite("join_right", &onmt::Token::join_right)
.def_readwrite("spacer", &onmt::Token::spacer)
.def_readwrite("preserve", &onmt::Token::preserve)
.def_readwrite("features", &onmt::Token::features)
.def_readwrite("casing", &onmt::Token::case_type)
// NOTE(review): the boolean "subword" flag overlaps with the new
// "type" attribute — confirm it is still meant to be exposed here.
.def_readwrite("subword", &onmt::Token::subword)
.def("__eq__", &onmt::Token::operator==)
;

Expand Down
12 changes: 9 additions & 3 deletions bindings/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,19 @@ The Token API allows to tokenize text into `pyonmttok.Token` objects. This API c
The `pyonmttok.Token` class has the following attributes:

* `surface`: a string, the token value
* `type`: a `pyonmttok.TokenType` value, the type of the token
* `join_left`: a boolean, whether the token should be joined to the token on the left or not
* `join_right`: a boolean, whether the token should be joined to the token on the right or not
* `spacer`: a boolean, whether the token is a spacer
* `preserve`: a boolean, whether joiners and spacers can be attached to this token or not
* `features`: a list of strings, the features attached to the token
* `casing`: a `pyonmttok.Casing` value, the casing of the token
* `subword`: a boolean, whether the token is a subword
* `spacer`: a boolean, whether the token is prefixed by a SentencePiece spacer or not (only set when using SentencePiece)
* `casing`: a `pyonmttok.Casing` value, the casing of the token (only set when tokenizing with `case_feature` or `case_markup`)

The `pyonmttok.TokenType` enumeration can take the following values:

* `TokenType.WORD`
* `TokenType.LEADING_SUBWORD`
* `TokenType.TRAILING_SUBWORD`

The `pyonmttok.Casing` enumeration can take the following values:

Expand Down
10 changes: 5 additions & 5 deletions bindings/python/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,11 @@ def test_token_api_with_subword():
text = "BONJOUR MONDE"
def _check_subword(tokens):
assert len(tokens) == 5
assert not tokens[0].subword # bon
assert tokens[1].subword # j
assert tokens[2].subword # our
assert not tokens[3].subword # mon
assert tokens[4].subword # de
assert tokens[0].type == pyonmttok.TokenType.LEADING_SUBWORD # bon
assert tokens[1].type == pyonmttok.TokenType.TRAILING_SUBWORD # j
assert tokens[2].type == pyonmttok.TokenType.TRAILING_SUBWORD # our
assert tokens[3].type == pyonmttok.TokenType.LEADING_SUBWORD # mon
assert tokens[4].type == pyonmttok.TokenType.TRAILING_SUBWORD # de

tokens = tokenizer.tokenize(text, as_tokens=True)
_check_subword(tokens)
Expand Down
9 changes: 8 additions & 1 deletion include/onmt/Token.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,22 @@
namespace onmt
{

// Relation of a token to subword segmentation:
//  - Word: a full word, not produced by splitting;
//  - LeadingSubword: the first piece of a word split into subwords;
//  - TrailingSubword: any subsequent piece of a split word.
// (SubwordEncoder marks tokens.front() as LeadingSubword and the
// remaining pieces as TrailingSubword when a word is split.)
enum class TokenType
{
Word,
LeadingSubword,
TrailingSubword,
};

struct OPENNMTTOKENIZER_EXPORT Token
{
std::string surface;
TokenType type;
CaseModifier::Type case_type = CaseModifier::Type::None;
bool join_left = false;
bool join_right = false;
bool spacer = false;
bool preserve = false;
bool subword = false;
std::vector<std::string> features;

Token() = default;
Expand Down
2 changes: 1 addition & 1 deletion src/CaseModifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ namespace onmt
if (false
|| (!soft
&& case_type == Type::Uppercase
&& token.subword)
&& token.type == TokenType::TrailingSubword)
|| (soft
&& (case_type == Type::Uppercase
|| (case_type == Type::Capitalized && token.unicode_length() == 1)
Expand Down
13 changes: 9 additions & 4 deletions src/SubwordEncoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,9 @@ namespace onmt
tokens.back().preserve = true;
}

for (size_t i = 0; i < tokens.size(); ++i)
if (token.has_case())
{
if (i > 0)
tokens[i].subword = true;
if (token.has_case())
for (size_t i = 0; i < tokens.size(); ++i)
{
auto case_type = token.case_type;
if (case_type == CaseModifier::Type::Capitalized && i > 0)
Expand All @@ -109,6 +107,13 @@ namespace onmt
}
}

if (tokens.size() > 1)
{
tokens.front().type = TokenType::LeadingSubword;
for (size_t i = 1; i < tokens.size(); ++i)
tokens[i].type = TokenType::TrailingSubword;
}

if (token.has_features())
{
for (auto& sub_token : tokens)
Expand Down