Skip to content

Commit 0125a58

Browse files
lucylqfacebook-github-bot
authored andcommitted
Update jni runner (#8978)
Summary: Add data map to JNI layer and LlamaModule ctor. Reviewed By: cmodi-meta, kirklandsign Differential Revision: D70597652
1 parent 352416e commit 0125a58

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

extension/android/jni/jni_layer_llama.cpp

+21-7
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,22 @@ class ExecuTorchLlamaJni
132132
jint model_type_category,
133133
facebook::jni::alias_ref<jstring> model_path,
134134
facebook::jni::alias_ref<jstring> tokenizer_path,
135-
jfloat temperature) {
135+
jfloat temperature,
136+
facebook::jni::alias_ref<jstring> data_path) {
136137
return makeCxxInstance(
137-
model_type_category, model_path, tokenizer_path, temperature);
138+
model_type_category,
139+
model_path,
140+
tokenizer_path,
141+
temperature,
142+
data_path);
138143
}
139144

140145
ExecuTorchLlamaJni(
141146
jint model_type_category,
142147
facebook::jni::alias_ref<jstring> model_path,
143148
facebook::jni::alias_ref<jstring> tokenizer_path,
144-
jfloat temperature) {
149+
jfloat temperature,
150+
facebook::jni::alias_ref<jstring> data_path = nullptr) {
145151
#if defined(ET_USE_THREADPOOL)
146152
// Reserve 1 thread for the main thread.
147153
uint32_t num_performant_cores =
@@ -160,10 +166,18 @@ class ExecuTorchLlamaJni
160166
tokenizer_path->toStdString().c_str(),
161167
temperature);
162168
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
163-
runner_ = std::make_unique<example::Runner>(
164-
model_path->toStdString().c_str(),
165-
tokenizer_path->toStdString().c_str(),
166-
temperature);
169+
if (data_path != nullptr) {
170+
runner_ = std::make_unique<example::Runner>(
171+
model_path->toStdString().c_str(),
172+
tokenizer_path->toStdString().c_str(),
173+
temperature,
174+
data_path->toStdString().c_str());
175+
} else {
176+
runner_ = std::make_unique<example::Runner>(
177+
model_path->toStdString().c_str(),
178+
tokenizer_path->toStdString().c_str(),
179+
temperature);
180+
}
167181
#if defined(EXECUTORCH_BUILD_MEDIATEK)
168182
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
169183
runner_ = std::make_unique<MTKLlamaRunner>(

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

+12-4
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,24 @@ public class LlamaModule {
3939

4040
@DoNotStrip
4141
private static native HybridData initHybrid(
42-
int modelType, String modulePath, String tokenizerPath, float temperature);
42+
int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);
4343

44-
/** Constructs a LLAMA Module for a model with given path, tokenizer, and temperature. */
44+
/** Constructs a LLAMA Module for a model with given model path, tokenizer, temperature. */
4545
public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
46-
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature);
46+
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
47+
}
48+
49+
/**
50+
* Constructs a LLAMA Module for a model with given model path, tokenizer, temperature and data
51+
* path.
52+
*/
53+
public LlamaModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
54+
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
4755
}
4856

4957
/** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
5058
public LlamaModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
51-
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature);
59+
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
5260
}
5361

5462
public void resetNative() {

0 commit comments

Comments
 (0)