@@ -132,16 +132,22 @@ class ExecuTorchLlamaJni
132
132
jint model_type_category,
133
133
facebook::jni::alias_ref<jstring> model_path,
134
134
facebook::jni::alias_ref<jstring> tokenizer_path,
135
- jfloat temperature) {
135
+ jfloat temperature,
136
+ facebook::jni::alias_ref<jstring> data_path) {
136
137
return makeCxxInstance (
137
- model_type_category, model_path, tokenizer_path, temperature);
138
+ model_type_category,
139
+ model_path,
140
+ tokenizer_path,
141
+ temperature,
142
+ data_path);
138
143
}
139
144
140
145
ExecuTorchLlamaJni (
141
146
jint model_type_category,
142
147
facebook::jni::alias_ref<jstring> model_path,
143
148
facebook::jni::alias_ref<jstring> tokenizer_path,
144
- jfloat temperature) {
149
+ jfloat temperature,
150
+ facebook::jni::alias_ref<jstring> data_path = nullptr ) {
145
151
#if defined(ET_USE_THREADPOOL)
146
152
// Reserve 1 thread for the main thread.
147
153
uint32_t num_performant_cores =
@@ -160,10 +166,18 @@ class ExecuTorchLlamaJni
160
166
tokenizer_path->toStdString ().c_str (),
161
167
temperature);
162
168
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
163
- runner_ = std::make_unique<example::Runner>(
164
- model_path->toStdString ().c_str (),
165
- tokenizer_path->toStdString ().c_str (),
166
- temperature);
169
+ if (data_path != nullptr ) {
170
+ runner_ = std::make_unique<example::Runner>(
171
+ model_path->toStdString ().c_str (),
172
+ tokenizer_path->toStdString ().c_str (),
173
+ temperature,
174
+ data_path->toStdString ().c_str ());
175
+ } else {
176
+ runner_ = std::make_unique<example::Runner>(
177
+ model_path->toStdString ().c_str (),
178
+ tokenizer_path->toStdString ().c_str (),
179
+ temperature);
180
+ }
167
181
#if defined(EXECUTORCH_BUILD_MEDIATEK)
168
182
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
169
183
runner_ = std::make_unique<MTKLlamaRunner>(
0 commit comments