Skip to content

Commit

Permalink
Add a decoder for Unigram tokenizer and unify some classes among toke…
Browse files Browse the repository at this point in the history
…nizers (#816)

* rename and formalize the file names

* add the decoder impl

* fix a typo
  • Loading branch information
wenbingl authored Sep 25, 2024
1 parent 6b94f4d commit f204a4c
Show file tree
Hide file tree
Showing 21 changed files with 357 additions and 236 deletions.
2 changes: 1 addition & 1 deletion operators/tokenizer/bpe_decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

#include "ustring.h"
#include "narrow.h"
#include "tokjson_types.h"
#include "tokenizer_common.h"

struct KernelBpeDecoder {
public:
Expand Down
24 changes: 11 additions & 13 deletions operators/tokenizer/bpe_kernels.cc
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "file_sys.h"

#include "bpe_kernels.h"
#include "bpe_jsoncfg.hpp"
#include "bpe_tokenizer.hpp"
#include <limits>
#include <optional>

#include "base64.h"

#include <optional>
#include <limits>
#include "file_sys.h"
#include "bpe_kernels.h"
#include "tokenizer_jsconfig.hpp"
#include "bpe_tokenizer_model.hpp"

using namespace ort_extensions;

Expand Down Expand Up @@ -673,11 +671,11 @@ struct VectorEqual {
}
};

OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config) {
OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
bpe::AddedToken added_token;
AddedToken added_token;
added_token.id_ = token.value("id", 0);
added_token.token_type_ = token.value("__type", "");
added_token.content_ = token.value("content", "");
Expand Down Expand Up @@ -721,7 +719,7 @@ bool JsonFastTokenizer::CheckForSpmModel(const json& tok_json) {
return false;
}

void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config) {
void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
if (!config.add_bos_token_ && !config.bos_token_.empty()) {
auto post_processor = tok_json.find("post_processor");
if (post_processor != tok_json.end()) {
Expand All @@ -736,7 +734,7 @@ void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort
}
}

OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
Expand Down Expand Up @@ -785,7 +783,7 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
return status;
}

OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::bpe::TokenJsonConfig& config) {
OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
Expand Down
17 changes: 6 additions & 11 deletions operators/tokenizer/bpe_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@
#include <vector>
#include <functional>

#include "ortx_tokenizer.h"
#include "ext_status.h"
#include "op_def_struct.h"
#include "nlohmann/json_fwd.hpp"
#include "tokjson_types.h"
#include "ustring.h"
#include "tokenizer_common.h"


struct BpeModelConf {
Expand Down Expand Up @@ -116,8 +111,8 @@ struct SpmTokenizer : KernelBpeTokenizer {
class JsonFastTokenizer : public KernelBpeTokenizer {
public:
JsonFastTokenizer();
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus LoadTikTokenBase64(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Load(const ort_extensions::TokenJsonConfig& config);
OrtxStatus LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask = std::nullopt,
Expand All @@ -133,9 +128,9 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
// template functions to avoid including the huge json header file
bool CheckForSpmModel(const json& tok_json);
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::bpe::TokenJsonConfig& config);
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config);

BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
std::vector<ort_extensions::AddedToken> added_tokens_;
};
17 changes: 5 additions & 12 deletions operators/tokenizer/bpe_streaming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,19 @@

#include "bpe_kernels.h"
#include "bpe_decoder.hpp"
#include "bpe_jsoncfg.hpp"
#include "bpe_tokenizer.hpp"

namespace ort_extensions {
struct BPEDecoderState {
bool f_special_last{};
std::string incomplete_utf8_;
};
} // namespace ort_extensions
#include "tokenizer_jsconfig.hpp"
#include "bpe_tokenizer_model.hpp"

class BpeStreamingDecoder : public KernelBpeDecoder {
public:
BpeStreamingDecoder() = default;
~BpeStreamingDecoder() override = default;

using BPEDecoderState = ort_extensions::BPEDecoderState;
using BPEDecoderState = ort_extensions::TokenizerDecodingState;

// shared the data between the encoder and decoder
OrtxStatus Load(
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> ptr_config,
std::shared_ptr<ort_extensions::TokenJsonConfig const> ptr_config,
const JsonFastTokenizer& encoder) {
const auto& tok_config = *ptr_config;
bos_token_ = tok_config.bos_token_;
Expand Down Expand Up @@ -258,5 +251,5 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
extTokenId_t eos_token_id_{0};
bool add_dummy_prefix_ = false;
bool spm_model_{};
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig const> tok_config_;
std::shared_ptr<ort_extensions::TokenJsonConfig const> tok_config_;
};
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "nlohmann/json.hpp"
#include "bpe_utils.hpp"
#include "trietree.hpp"
#include "tokjson_types.h"
#include "tokenizer_common.h"

namespace ort_extensions {

Expand Down Expand Up @@ -249,7 +249,7 @@ class BpeModel {
return {};
}

OrtxStatus LoadAddedTokens(const std::vector<bpe::AddedToken>& added_tokens) {
OrtxStatus LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
for (const auto& token : added_tokens) {
added_tokens_.Add(ustring(token.content_), 0, token.id_);
}
Expand Down
2 changes: 1 addition & 1 deletion operators/tokenizer/sentencepiece_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "sentencepiece_processor.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_tokenizer.hpp"
#include "sentencepiece_tokenizer.h"
#include "string_tensor.h"
#include "base64.h"
#include "narrow.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@

#include <string>

#include "ortx_tokenizer.h"
#include "ext_status.h"
#include "op_def_struct.h"
#include "nlohmann/json_fwd.hpp"

#include "ustring.h"


namespace ort_extensions {
class BpeModel;

namespace bpe {

struct AddedToken final {
uint32_t id_{};
std::string token_type_;
Expand All @@ -23,7 +29,10 @@ struct AddedToken final {

class TokenJsonConfig; // forward declaration

} // namespace bpe
struct TokenizerDecodingState {
bool f_special_last{};
std::string incomplete_utf8_;
};

constexpr std::string_view spm_escaped_space = "\xE2\x96\x81";
} // namespace ort_extensions
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

#pragma once

#include "ocos.h"
#include "file_sys.h"
#include "nlohmann/json.hpp"

#include "tokjson_types.h"
#include "tokenizer_common.h"

namespace ort_extensions::bpe {
namespace ort_extensions {

// TokenJsonConfig: Handles loading and parsing of JSON configuration files for tokenizers
class TokenJsonConfig final {
public:
static constexpr const char* kDefaultVocabFile = "tokenizer.json";
Expand All @@ -26,6 +26,7 @@ class TokenJsonConfig final {
return OrtxStatus(kOrtxErrorInvalidArgument, "json_path is empty.");
}


ortx::path tok_dir(json_path);
ortx::path vocab_path(json_path);
ortx::path tok_path_obj(json_path);
Expand Down Expand Up @@ -122,4 +123,4 @@ class TokenJsonConfig final {
std::string module_path_;
};

} // namespace ort_extensions::bpe
} // namespace ort_extensions
6 changes: 3 additions & 3 deletions operators/tokenizer/tokenizers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
#include "ocos.h"

#ifdef ENABLE_GPT2_TOKENIZER
#include "bpe_tokenizer.hpp"
#include "bpe_kernels.h"
#include "bpe_tokenizer_model.hpp"
#include "bpe_decoder.hpp"
#endif

#ifdef ENABLE_SPM_TOKENIZER
#include "sentencepiece_tokenizer.hpp"
#include "sentencepiece_tokenizer.h"
#include "sentencepiece_decoder.hpp"
#endif

#ifdef ENABLE_WORDPIECE_TOKENIZER
#include "wordpiece_tokenizer.hpp"
#include "wordpiece_tokenizer.h"
#endif

#ifdef ENABLE_BLINGFIRE
Expand Down
5 changes: 3 additions & 2 deletions operators/tokenizer/trie_tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <charconv>
#include <optional>

#include "unescape.h"
#include "unescape.hpp"
#include "trietree.hpp"

// This Trie Tree is C++ implementation of
Expand All @@ -40,6 +40,7 @@ class TrieTokenizer {
private:
std::map<int, std::string> idx2token;
RWKVTrieTree root;
using UnescapeUtils = ort_extensions::UnescapeUtils;

public:
TrieTokenizer(const std::string& text_tokens) {
Expand All @@ -62,7 +63,7 @@ class TrieTokenizer {
std::string raw = line.substr(line.find(' ') + 1, line.rfind(' ') - line.find(' ') - 1);
std::string x;
int key_length = 0;
if (ort_extensions::UnquoteString(raw, x)) {
if (UnescapeUtils::UnquoteString(raw, x)) {
std::from_chars(line.data() + r_ws + 1, line.data() + line.size(), key_length);
}
if (x.length() != key_length) {
Expand Down
Loading

0 comments on commit f204a4c

Please sign in to comment.