#include <executorch/extension/llm/runner/util.h>

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
-#include <pytorch/tokenizers/llama2c_tokenizer.h>
#include <pytorch/tokenizers/hf_tokenizer.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>

namespace example {

@@ -35,6 +35,41 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
static constexpr auto kVocabSize = "get_vocab_size";
static constexpr auto kUseKVCache = "use_kv_cache";
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+
+std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
+    const std::string& tokenizer_path) {
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = nullptr;
+  ::tokenizers::Error err;
+
+  // First try to load as a json tokenizer.
+  {
+    auto tokenizer = std::make_unique<tokenizers::HFTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded json tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as tiktoken tokenizer.
+  {
+    auto tokenizer = get_tiktoken_for_llama();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded TikToken tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as BPE tokenizer.
+  {
+    auto tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded BPE tokenizer");
+      return tokenizer;
+    }
+  }
+
+  return nullptr;
+}

} // namespace

Runner::Runner(
@@ -76,35 +111,15 @@ Error Runner::load() {
    return Error::Ok;
  }
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
+
  // Load tokenizer.
-  tokenizer_ = nullptr;
-  // Check if tokenizer_path_ ends with ".json".
-  if (tokenizer_path_.size() >= 5 &&
-
-      tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) {
-    tokenizer_ = std::make_unique<tokenizers::HFTokenizer>();
-    ET_LOG(Info, "Loading json tokenizer");
-    tokenizer_->load(tokenizer_path_);
+  tokenizer_ = load_tokenizer(tokenizer_path_);
+  if (tokenizer_ == nullptr) {
    ET_LOG(
-        Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str());
-  } else {
-    ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-    tokenizer_ = get_tiktoken_for_llama();
-    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-    // fallback to BPE tokenizer.
-    if (err != ::tokenizers::Error::Ok) {
-      ET_LOG(
-          Info,
-          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-          tokenizer_path_.c_str());
-      tokenizer_.reset();
-      tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-      err = tokenizer_->load(tokenizer_path_);
-      ET_CHECK_TK_OK_OR_RETURN_ERROR(
-          err,
-          "Failed to load %s as a llama2.c tokenizer artifact",
-          tokenizer_path_.c_str());
-    }
+        Error,
+        "Failed to load %s as a llama2.c tokenizer artifact",
+        tokenizer_path_.c_str());
+    return ::executorch::runtime::Error::InvalidArgument;
  }

  ET_LOG(Info, "Reading metadata from model");
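
For reviewers who want to exercise the new fallback behaviour outside the runner, the standalone sketch below mirrors the probe order introduced by this diff (HF JSON first, then Tiktoken, then llama2.c BPE). It is illustrative only: probe_tokenizer and the main harness are hypothetical and not part of this change, the chain is re-stated here because the real load_tokenizer sits in an anonymous namespace inside runner.cpp, and it assumes get_tiktoken_for_llama() is reachable through the example namespace, as the include and call site above suggest.

// Illustrative sketch only -- mirrors the fallback order of the new
// load_tokenizer() helper; not part of this diff.
#include <memory>
#include <string>

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <pytorch/tokenizers/hf_tokenizer.h>
#include <pytorch/tokenizers/llama2c_tokenizer.h>

// Hypothetical helper: try each supported tokenizer format in turn and
// return the first one whose load() succeeds, or nullptr if none match.
std::unique_ptr<::tokenizers::Tokenizer> probe_tokenizer(
    const std::string& path) {
  {
    // 1. Hugging Face JSON tokenizer.
    auto t = std::make_unique<::tokenizers::HFTokenizer>();
    if (t->load(path) == ::tokenizers::Error::Ok) {
      return t;
    }
  }
  {
    // 2. Tiktoken artifact, via the llama_tiktoken helper.
    auto t = example::get_tiktoken_for_llama();
    if (t->load(path) == ::tokenizers::Error::Ok) {
      return t;
    }
  }
  {
    // 3. llama2.c-style BPE artifact.
    auto t = std::make_unique<::tokenizers::Llama2cTokenizer>();
    if (t->load(path) == ::tokenizers::Error::Ok) {
      return t;
    }
  }
  return nullptr;
}

int main(int argc, char** argv) {
  if (argc < 2) {
    return 1;  // usage: probe_tokenizer <path-to-tokenizer-artifact>
  }
  // Exit 0 if any of the three formats accepted the file, 1 otherwise.
  return probe_tokenizer(argv[1]) == nullptr ? 1 : 0;
}

Probing each format in order, rather than keying off the file extension as the removed ".json" check did, lets one code path handle all three artifact types and keeps the error handling in a single place in Runner::load().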