Cleanup Token structure #153

Merged · 10 commits · Aug 3, 2020
16 changes: 11 additions & 5 deletions bindings/python/Python.cc
@@ -2,6 +2,7 @@
#include <memory>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <onmt/Tokenizer.h>
#include <onmt/BPE.h>
@@ -175,9 +176,9 @@ class TokenizerWrapper
_tokenizer.reset(tokenizer);
}

py::object tokenize(const std::string& text, const bool as_tokens) const
py::object tokenize(const std::string& text, const bool as_token_objects) const
{
if (as_tokens)
if (as_token_objects)
{
std::vector<onmt::Token> tokens;
_tokenizer->tokenize(text, tokens);
@@ -447,18 +448,23 @@ PYBIND11_MODULE(pyonmttok, m)
.value("NONE", onmt::CaseModifier::Type::None)
.export_values();

py::enum_<onmt::TokenType>(m, "TokenType")
.value("WORD", onmt::TokenType::Word)
.value("LEADING_SUBWORD", onmt::TokenType::LeadingSubword)
.value("TRAILING_SUBWORD", onmt::TokenType::TrailingSubword)
.export_values();

py::class_<onmt::Token>(m, "Token")
.def(py::init<>())
.def(py::init<std::string>())
.def_readwrite("surface", &onmt::Token::surface)
.def_readwrite("type", &onmt::Token::type)
.def_readwrite("join_left", &onmt::Token::join_left)
.def_readwrite("join_right", &onmt::Token::join_right)
.def_readwrite("spacer", &onmt::Token::spacer)
.def_readwrite("preserve", &onmt::Token::preserve)
.def_readwrite("features", &onmt::Token::features)
.def_readwrite("casing", &onmt::Token::case_type)
.def_readwrite("begin_case_region", &onmt::Token::begin_case_region)
.def_readwrite("end_case_region", &onmt::Token::end_case_region)
.def("__eq__", &onmt::Token::operator==)
;

@@ -491,7 +497,7 @@ PYBIND11_MODULE(pyonmttok, m)
py::arg("segment_alphabet")=py::list())
.def("tokenize", &TokenizerWrapper::tokenize,
py::arg("text"),
py::arg("as_tokens")=false)
py::arg("as_token_objects")=false)
.def("serialize_tokens", &TokenizerWrapper::serialize_tokens,
py::arg("tokens"))
.def("deserialize_tokens", &TokenizerWrapper::deserialize_tokens,
72 changes: 50 additions & 22 deletions bindings/python/README.md
@@ -18,9 +18,9 @@ pip install pyonmttok

### Interface

```python
import pyonmttok
#### Constructor

```python
tokenizer = pyonmttok.Tokenizer(
mode: str,
bpe_model_path: str = "",
@@ -43,24 +43,44 @@ tokenizer = pyonmttok.Tokenizer(
segment_numbers: bool = False,
segment_alphabet_change: bool = False,
support_prior_joiners: bool = False,
segment_alphabet: list = [])
segment_alphabet: List[str] = [])
```

# See section "Token API" below for more information about the as_tokens argument.
tokens, features = tokenizer.tokenize(text: str, as_tokens: bool = False)
See the [documentation](../../docs/options.md) for a description of each tokenization option.

text = tokenizer.detokenize(tokens: list, features: list = None)
#### Tokenization

# Function that also returns a dictionary mapping a token index to a range in
# the detokenized text. Set merge_ranges=True to merge consecutive ranges, e.g.
# subwords of the same token in case of subword tokenization.
text, ranges = tokenizer.detokenize_with_ranges(tokens: list, merge_ranges: bool = True)
```python
# By default, tokenize returns the tokens and features.
tokenizer.tokenize(text: str) -> Tuple[List[str], List[List[str]]]

# Setting as_token_objects=True returns Token objects instead (see below).
tokenizer.tokenize(text: str, as_token_objects=True) -> List[pyonmttok.Token]

# File-based APIs
# Tokenize a file.
tokenizer.tokenize_file(input_path: str, output_path: str, num_threads: int = 1)
tokenizer.detokenize_file(input_path: str, output_path: str)
```

See the [documentation](../../docs/options.md) for a description of each option.
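
For reference, a minimal usage sketch of both return modes; the token values shown in the comments are illustrative and assume the `aggressive` mode with joiner annotation:

```python
import pyonmttok

tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)

# Default mode: a tuple of (tokens, features).
tokens, features = tokenizer.tokenize("Hello World!")
# tokens == ["Hello", "World", "￭!"]

# Token object mode: a list of pyonmttok.Token (see the "Token API" section below).
token_objects = tokenizer.tokenize("Hello World!", as_token_objects=True)
# token_objects[-1].surface == "!" and token_objects[-1].join_left is True
```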
#### Detokenization

```python
# The detokenize method converts tokens back to a string.
tokenizer.detokenize(
tokens: Union[List[str], List[pyonmttok.Token]],
features: List[List[str]] = None
) -> str

# The detokenize_with_ranges method also returns a dictionary mapping a token
# index to a range in the detokenized text. Set merge_ranges=True to merge
# consecutive ranges, e.g. subwords of the same token in case of subword tokenization.
tokenizer.detokenize_with_ranges(
tokens: Union[List[str], List[pyonmttok.Token]],
merge_ranges: bool = True
) -> Tuple[str, Dict[int, Tuple[int, int]]]

# Detokenize a file.
tokenizer.detokenize_file(input_path: str, output_path: str)
```
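
A short sketch of both detokenization entry points, assuming the same `aggressive` tokenizer as above; the exact ranges depend on the tokenization options:

```python
import pyonmttok

tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
tokens, _ = tokenizer.tokenize("Hello World!")

# Plain detokenization back to a string.
text = tokenizer.detokenize(tokens)  # "Hello World!"

# Detokenization with ranges: maps a token index to a character range
# in the detokenized text.
text, ranges = tokenizer.detokenize_with_ranges(tokens, merge_ranges=True)
for token_index, char_range in sorted(ranges.items()):
    print(token_index, char_range)
```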

## Subword learning

@@ -121,7 +141,7 @@ learner.ingest(text: str)
learner.ingest_file(path: str)
learner.ingest_token(token: str)

tokenizer = learner.learn(model_path: str, verbose: bool = False)
learner.learn(model_path: str, verbose: bool = False) -> pyonmttok.Tokenizer
```
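
As an illustration of the learner interface, a hedged sketch using the BPE learner; the `pyonmttok.BPELearner` constructor and the file paths are assumptions, not part of this diff:

```python
import pyonmttok

tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
# Assumed constructor: a BPE learner producing 32k merge operations.
learner = pyonmttok.BPELearner(tokenizer=tokenizer, symbols=32000)

learner.ingest("Hello world!")      # feed raw text
learner.ingest_file("train.txt")    # hypothetical training file

# learn() writes the model and returns a Tokenizer configured to use it.
bpe_tokenizer = learner.learn("bpe.model")
```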

## Token API
@@ -132,7 +152,7 @@ The Token API allows to tokenize text into `pyonmttok.Token` objects. This API c

```python
>>> tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
>>> tokens = tokenizer.tokenize("Hello World!", as_tokens=True)
>>> tokens = tokenizer.tokenize("Hello World!", as_token_objects=True)
>>> tokens[-1].surface
'!'
>>> tokenizer.serialize_tokens(tokens)[0]
@@ -149,16 +169,21 @@ The Token API allows to tokenize text into `pyonmttok.Token` objects. This API c
The `pyonmttok.Token` class has the following attributes:

* `surface`: a string, the token value
* `type`: a `pyonmttok.TokenType` value, the type of the token
* `join_left`: a boolean, whether the token should be joined to the token on the left or not
* `join_right`: a boolean, whether the token should be joined to the token on the right or not
* `spacer`: a boolean, whether the token is a spacer
* `preserve`: a boolean, whether joiners and spacers can be attached to this token or not
* `features`: a list of strings, the features attached to the token
* `casing`: a `pyonmttok.Casing` value, the casing of the token
* `begin_case_region`: a `pyonmttok.Casing` value, the casing region that the token opens
* `end_case_region`: a `pyonmttok.Casing` value, the casing region that the token closes
* `spacer`: a boolean, whether the token is prefixed by a SentencePiece spacer or not (only set when using SentencePiece)
* `casing`: a `pyonmttok.Casing` value, the casing of the token (only set when tokenizing with `case_feature` or `case_markup`)
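
For instance, a minimal sketch showing that these attributes can be read and modified directly before detokenizing, mirroring the attribute edits in the test added by this pull request:

```python
import pyonmttok

tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
tokens = tokenizer.tokenize("Hello World!", as_token_objects=True)

tokens[0].surface = "Hi"       # change the token value
tokens[2].join_left = False    # detach "!" from the previous token

print(tokenizer.detokenize(tokens))  # "Hi World !"
```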

The `pyonmttok.TokenType` enumeration is used to identify tokens that were split by a subword tokenization. The enumeration has the following values:

* `TokenType.WORD`
* `TokenType.LEADING_SUBWORD`
* `TokenType.TRAILING_SUBWORD`
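
A minimal sketch of how these values behave: without a subword model every token is a plain word; with a BPE or SentencePiece model (not configured here), split pieces are reported as leading and trailing subwords, as the test added in this pull request shows:

```python
import pyonmttok

tokenizer = pyonmttok.Tokenizer("conservative")
tokens = tokenizer.tokenize("Hello world", as_token_objects=True)

# No subword model is configured, so nothing is split.
assert all(token.type == pyonmttok.TokenType.WORD for token in tokens)

# With e.g. bpe_model_path set, split pieces would instead carry
# TokenType.LEADING_SUBWORD followed by TokenType.TRAILING_SUBWORD.
```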

The `pyonmttok.Casing` enumeration can take the following values:
The `pyonmttok.Casing` enumeration is used to identify the original casing of a token that was lowercased by the `case_feature` or `case_markup` tokenization options. The enumeration has the following values:

* `Casing.LOWERCASE`
* `Casing.UPPERCASE`
@@ -170,8 +195,11 @@ The `Tokenizer` instances provide methods to serialize or deserialize `Token` ob

```python
# Serialize Token objects to strings that can be saved on disk.
tokens, features = tokenizer.serialize_tokens(tokens: list)
tokenizer.serialize_tokens(tokens: List[pyonmttok.Token]) -> Tuple[List[str], List[List[str]]]

# Deserialize strings into Token objects.
tokens = tokenizer.deserialize_tokens(tokens: list, features: list = None)
tokenizer.deserialize_tokens(
tokens: List[str],
features: List[List[str]] = None
) -> List[pyonmttok.Token]
```
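
A round-trip sketch, assuming the same `aggressive` tokenizer as in the examples above; the serialized value shown in the comment is illustrative:

```python
import pyonmttok

tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
tokens = tokenizer.tokenize("Hello World!", as_token_objects=True)

# Token objects -> plain strings (plus any features), suitable for writing to disk.
serialized, _ = tokenizer.serialize_tokens(tokens)
# serialized == ["Hello", "World", "￭!"]

# Plain strings -> Token objects again; join and subword information is recovered.
restored = tokenizer.deserialize_tokens(serialized)
assert tokenizer.serialize_tokens(restored)[0] == serialized
```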
61 changes: 56 additions & 5 deletions bindings/python/test/test.py
@@ -187,7 +188,8 @@ def test_learner_with_invalid_files(tmpdir, learner):
def test_token_api():
tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True, case_markup=True)

tokens = tokenizer.tokenize("Hello WORLD!", as_tokens=True)
text = "Hello WORLD!"
tokens = tokenizer.tokenize(text, as_token_objects=True)
assert len(tokens) == 3
for token in tokens:
assert isinstance(token, pyonmttok.Token)
@@ -196,12 +197,13 @@ def test_token_api():
assert tokens[0].casing == pyonmttok.Casing.CAPITALIZED
assert tokens[1].surface == "world"
assert tokens[1].casing == pyonmttok.Casing.UPPERCASE
assert tokens[1].begin_case_region == pyonmttok.Casing.UPPERCASE
assert tokens[1].end_case_region == pyonmttok.Casing.UPPERCASE
assert tokens[2].surface == "!"
assert tokens[2].join_left
assert tokenizer.detokenize(tokens) == "Hello WORLD!"
assert tokenizer.detokenize_with_ranges(tokens)[0] == "Hello WORLD!"

assert tokenizer.detokenize(tokens) == text
detokenized_text, ranges = tokenizer.detokenize_with_ranges(tokens)
assert detokenized_text == text
assert list(sorted(ranges.keys())) == [0, 1, 2]

serialized_tokens, _ = tokenizer.serialize_tokens(tokens)
assert serialized_tokens == [
@@ -219,3 +221,52 @@
tokens[0].casing = pyonmttok.Casing.LOWERCASE
tokens[2].join_left = False
assert tokenizer.detokenize(tokens) == "toto WORLD !"

def test_token_api_with_subword():
tokenizer = pyonmttok.Tokenizer(
"conservative",
case_markup=True,
joiner_annotate=True,
bpe_model_path=os.path.join(_DATA_DIR, "bpe-models", "codes_suffix_case_insensitive.fr"))

text = "BONJOUR MONDE"
def _check_subword(tokens):
assert len(tokens) == 5
assert tokens[0].type == pyonmttok.TokenType.LEADING_SUBWORD # bon
assert tokens[1].type == pyonmttok.TokenType.TRAILING_SUBWORD # j
assert tokens[2].type == pyonmttok.TokenType.TRAILING_SUBWORD # our
assert tokens[3].type == pyonmttok.TokenType.LEADING_SUBWORD # mon
assert tokens[4].type == pyonmttok.TokenType.TRAILING_SUBWORD # de

tokens = tokenizer.tokenize(text, as_token_objects=True)
_check_subword(tokens)
serialized_tokens, _ = tokenizer.serialize_tokens(tokens)

# Deserialization should not lose subword information.
tokens = tokenizer.deserialize_tokens(serialized_tokens)
_check_subword(tokens)
assert serialized_tokens == tokenizer.serialize_tokens(tokens)[0]

def test_token_api_features():
tokenizer = pyonmttok.Tokenizer("space")
tokens = tokenizer.tokenize("a b", as_token_objects=True)
assert tokens[0].features == []
assert tokens[1].features == []

tokens = tokenizer.tokenize("a│1 b│2", as_token_objects=True)
assert tokens[0].features == ["1"]
assert tokens[1].features == ["2"]

tokens, features = tokenizer.serialize_tokens(tokens)
assert tokens == ["a", "b"]
assert features == [["1", "2"]]

# Case features should be deserialized into the casing attribute, not as features.
tokenizer = pyonmttok.Tokenizer("space", case_feature=True)
tokens = tokenizer.deserialize_tokens(["hello", "world"], features=[["C", "U"]])
assert tokens[0].surface == "hello"
assert tokens[0].casing == pyonmttok.Casing.CAPITALIZED
assert tokens[0].features == []
assert tokens[1].surface == "world"
assert tokens[1].casing == pyonmttok.Casing.UPPERCASE
assert tokens[1].features == []
30 changes: 25 additions & 5 deletions include/onmt/CaseModifier.h
@@ -1,23 +1,26 @@
#pragma once

#include <string>
#include <vector>

#include "onmt/opennmttokenizer_export.h"

namespace onmt
{

class Token;

// TODO: this should not be a class.
class OPENNMTTOKENIZER_EXPORT CaseModifier
{
public:
enum class Type
{
None,
Lowercase,
Uppercase,
Mixed,
Capitalized,
None
};

static std::pair<std::string, Type> extract_case_type(const std::string& token);
@@ -30,17 +33,34 @@

enum class Markup
{
None,
Modifier,
RegionBegin,
RegionEnd,
None
};

static Markup get_case_markup(const std::string& str);
static std::string generate_case_markup(Markup markup, Type type);
static Type get_case_modifier_from_markup(const std::string& markup);
static std::string generate_case_markup(Type type);
static std::string generate_case_markup_begin(Type type);
static std::string generate_case_markup_end(Type type);

struct TokenMarkup
{
TokenMarkup(Markup prefix_, Markup suffix_, Type type_)
: prefix(prefix_)
, suffix(suffix_)
, type(type_)
{
}
Markup prefix;
Markup suffix;
Type type;
};

// In "soft" mode, this function tries to minimize the number of uppercase regions by possibly
// including case invariant characters (numbers, symbols, etc.) in uppercase regions.
static std::vector<TokenMarkup>
get_case_markups(const std::vector<Token>& tokens, const bool soft = true);

};

}