Commit 63b2ed7

Scott pr review

1 parent 0883f32 commit 63b2ed7

1 file changed: examples/models/llama/runner/runner.cpp (+43 −28 lines)
@@ -16,8 +16,8 @@
 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
-#include <pytorch/tokenizers/llama2c_tokenizer.h>
 #include <pytorch/tokenizers/hf_tokenizer.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>
 
 namespace example {
 

@@ -35,6 +35,41 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+
+std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
+    const std::string& tokenizer_path) {
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = nullptr;
+  ::tokenizers::Error err;
+
+  // First try to load as a json tokenizer.
+  {
+    auto tokenizer = std::make_unique<tokenizers::HFTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded json tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as tiktoken tokenizer.
+  {
+    auto tokenizer = get_tiktoken_for_llama();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded TikToken tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as BPE tokenizer.
+  {
+    auto tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded BPE tokenizer");
+      return tokenizer;
+    }
+  }
+
+  return nullptr;
+}
 } // namespace
 
 Runner::Runner(
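
The new load_tokenizer helper replaces the old extension-based dispatch with trial loading: each supported format is attempted in priority order (HF json, then tiktoken, then llama2.c BPE), and the first parser that accepts the artifact wins. Below is a minimal standalone sketch of that try-in-order pattern; the TokenizerStub type, TryLoad callback, and load_first_match function are hypothetical stand-ins for illustration, not the executorch or pytorch/tokenizers API.

// Standalone sketch of the try-in-order loading pattern (C++17).
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct TokenizerStub {
  std::string format;
};

// A candidate loader returns a tokenizer on success, nullptr on failure.
using TryLoad =
    std::function<std::unique_ptr<TokenizerStub>(const std::string&)>;

// Walk the candidates in priority order and keep the first that parses.
std::unique_ptr<TokenizerStub> load_first_match(
    const std::vector<std::pair<std::string, TryLoad>>& candidates,
    const std::string& path) {
  for (const auto& [name, try_load] : candidates) {
    if (auto tokenizer = try_load(path)) {
      std::cout << "Loaded " << name << " tokenizer\n";
      return tokenizer;
    }
  }
  return nullptr;  // every format rejected the artifact; the caller decides
}

int main() {
  std::vector<std::pair<std::string, TryLoad>> candidates = {
      {"json", [](const std::string&) -> std::unique_ptr<TokenizerStub> {
         return nullptr;  // simulate an HF/json parse failure
       }},
      {"BPE", [](const std::string&) {
         return std::make_unique<TokenizerStub>(TokenizerStub{"BPE"});
       }},
  };
  auto tokenizer = load_first_match(candidates, "tokenizer.model");
  return tokenizer ? 0 : 1;
}

Scoping each attempt in its own braces, as the diff does, means a failed load cannot leak partially initialized tokenizer state into the next attempt.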
@@ -76,35 +111,15 @@ Error Runner::load() {
     return Error::Ok;
   }
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
+
   // Load tokenizer.
-  tokenizer_ = nullptr;
-  // Check if tokenizer_path_ ends with ".json".
-  if (tokenizer_path_.size() >= 5 &&
-      tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) {
-    tokenizer_ = std::make_unique<tokenizers::HFTokenizer>();
-    ET_LOG(Info, "Loading json tokenizer");
-    tokenizer_->load(tokenizer_path_);
+  tokenizer_ = load_tokenizer(tokenizer_path_);
+  if (tokenizer_ == nullptr) {
     ET_LOG(
-        Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str());
-  } else {
-    tokenizer_ = get_tiktoken_for_llama();
-    ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-    // fallback to BPE tokenizer.
-    if (err != ::tokenizers::Error::Ok) {
-      ET_LOG(
-          Info,
-          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-          tokenizer_path_.c_str());
-      tokenizer_.reset();
-      tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-      err = tokenizer_->load(tokenizer_path_);
-      ET_CHECK_TK_OK_OR_RETURN_ERROR(
-          err,
-          "Failed to load %s as a llama2.c tokenizer artifact",
-          tokenizer_path_.c_str());
-    }
+        Error,
+        "Failed to load %s as a llama2.c tokenizer artifact",
+        tokenizer_path_.c_str());
+    return ::executorch::runtime::Error::InvalidArgument;
   }
 
   ET_LOG(Info, "Reading metadata from model");
