Migrate Java API: LlamaModule -> LlmModule #9478

Merged · 10 commits · Mar 21, 2025
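For orientation before the file-by-file diff: the sketch below shows how a caller might move from the deprecated org.pytorch.executorch.LlamaModule API to the new org.pytorch.executorch.extension.llm.LlmModule API introduced here. The class, package, and method names are taken from this diff; the model path, tokenizer path, temperature, prompt, and wrapper class name are illustrative placeholders, not part of the change.

// Before (deprecated, still functional as a thin wrapper):
//   import org.pytorch.executorch.LlamaModule;
//   import org.pytorch.executorch.LlamaCallback;
//
// After (new package introduced by this PR):
import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

public class MigrationSketch {
  public static void main(String[] args) {
    // Illustrative paths and temperature; substitute your own artifacts.
    LlmModule module =
        new LlmModule("/data/local/tmp/model.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
    module.load(); // optional: otherwise the model is loaded on the first generate()

    // Full generate() signature, as used by the deprecated wrapper in this diff;
    // a null image with zero dimensions means text-only generation.
    module.generate(
        /* image */ null,
        /* width */ 0,
        /* height */ 0,
        /* channels */ 0,
        /* prompt */ "Tell me a story.",
        /* seqLen */ 128,
        new LlmCallback() {
          @Override
          public void onResult(String token) {
            System.out.print(token); // streamed token by token until generate() finishes
          }

          @Override
          public void onStats(float tps) {
            System.out.println("\ntokens/s: " + tps);
          }
        },
        /* echo */ true);
  }
}

The deprecated LlamaModule keeps compiling because, as the diff below shows, it now forwards every call to an internal LlmModule and adapts LlamaCallback to LlmCallback.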
2 changes: 2 additions & 0 deletions extension/android/BUCK
@@ -25,6 +25,8 @@ fb_android_library(
srcs = [
"src/main/java/org/pytorch/executorch/LlamaCallback.java",
"src/main/java/org/pytorch/executorch/LlamaModule.java",
"src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java",
"src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java",
],
autoglob = False,
language = "JAVA",
8 changes: 4 additions & 4 deletions extension/android/jni/jni_layer.cpp
@@ -408,14 +408,14 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
} // namespace executorch::extension

#ifdef EXECUTORCH_BUILD_LLAMA_JNI
extern void register_natives_for_llama();
extern void register_natives_for_llm();
#else
// No op if we don't build llama
void register_natives_for_llama() {}
// No op if we don't build LLM
void register_natives_for_llm() {}
#endif
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
executorch::extension::ExecuTorchJni::registerNatives();
register_natives_for_llama();
register_natives_for_llm();
});
}
39 changes: 19 additions & 20 deletions extension/android/jni/jni_layer_llama.cpp
@@ -75,14 +75,14 @@ std::string token_buffer;

namespace executorch_jni {

class ExecuTorchLlamaCallbackJni
: public facebook::jni::JavaClass<ExecuTorchLlamaCallbackJni> {
class ExecuTorchLlmCallbackJni
: public facebook::jni::JavaClass<ExecuTorchLlmCallbackJni> {
public:
constexpr static const char* kJavaDescriptor =
"Lorg/pytorch/executorch/LlamaCallback;";
"Lorg/pytorch/executorch/extension/llm/LlmCallback;";

void onResult(std::string result) const {
static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
static const auto method =
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");

@@ -99,7 +99,7 @@ class ExecuTorchLlamaCallbackJni
}

void onStats(const llm::Stats& result) const {
static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
static const auto method = cls->getMethod<void(jfloat)>("onStats");
double eval_time =
(double)(result.inference_end_ms - result.prompt_eval_end_ms);
@@ -111,8 +111,7 @@ class ExecuTorchLlamaCallbackJni
}
};

class ExecuTorchLlamaJni
: public facebook::jni::HybridClass<ExecuTorchLlamaJni> {
class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
private:
friend HybridBase;
int model_type_category_;
@@ -121,7 +120,7 @@ class ExecuTorchLlamaJni

public:
constexpr static auto kJavaDescriptor =
"Lorg/pytorch/executorch/LlamaModule;";
"Lorg/pytorch/executorch/extension/llm/LlmModule;";

constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
@@ -142,7 +141,7 @@ class ExecuTorchLlamaJni
data_path);
}

ExecuTorchLlamaJni(
ExecuTorchLlmJni(
jint model_type_category,
facebook::jni::alias_ref<jstring> model_path,
facebook::jni::alias_ref<jstring> tokenizer_path,
@@ -197,7 +196,7 @@ class ExecuTorchLlamaJni
jint channels,
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
auto image_size = image->size();
@@ -296,7 +295,7 @@ class ExecuTorchLlamaJni
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
jlong start_pos,
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
return static_cast<jint>(Error::NotSupported);
@@ -329,22 +328,22 @@ class ExecuTorchLlamaJni

static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", ExecuTorchLlamaJni::initHybrid),
makeNativeMethod("generate", ExecuTorchLlamaJni::generate),
makeNativeMethod("stop", ExecuTorchLlamaJni::stop),
makeNativeMethod("load", ExecuTorchLlamaJni::load),
makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid),
makeNativeMethod("generate", ExecuTorchLlmJni::generate),
makeNativeMethod("stop", ExecuTorchLlmJni::stop),
makeNativeMethod("load", ExecuTorchLlmJni::load),
makeNativeMethod(
"prefillImagesNative", ExecuTorchLlamaJni::prefill_images),
"prefillImagesNative", ExecuTorchLlmJni::prefill_images),
makeNativeMethod(
"prefillPromptNative", ExecuTorchLlamaJni::prefill_prompt),
"prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
makeNativeMethod(
"generateFromPos", ExecuTorchLlamaJni::generate_from_pos),
"generateFromPos", ExecuTorchLlmJni::generate_from_pos),
});
}
};

} // namespace executorch_jni

void register_natives_for_llama() {
executorch_jni::ExecuTorchLlamaJni::registerNatives();
void register_natives_for_llm() {
executorch_jni::ExecuTorchLlmJni::registerNatives();
}
extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java
@@ -9,15 +9,14 @@
package org.pytorch.executorch;

import com.facebook.jni.annotations.DoNotStrip;
import org.pytorch.executorch.annotations.Experimental;

/**
* Callback interface for Llama model. Users can implement this interface to receive the generated
* tokens and statistics.
*
* <p>Warning: These APIs are experimental and subject to change without notice
* <p>Note: deprecated! Please use {@link org.pytorch.executorch.extension.llm.LlmCallback} instead.
*/
@Experimental
@Deprecated
Contributor comment:
Since this was already tagged as experimental to begin with, we should just delete this altogether with the 0.6 release. We don't need to keep it for two releases, as we do with stable APIs.

https://pytorch.org/executorch/main/api-life-cycle.html#api-life-cycle

public interface LlamaCallback {
/**
* Called when a new result is available from JNI. Users will keep getting onResult() invocations
extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -8,59 +8,45 @@

package org.pytorch.executorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import org.pytorch.executorch.annotations.Experimental;
import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

/**
* LlamaModule is a wrapper around the Executorch Llama model. It provides a simple interface to
* generate text from the model.
*
* <p>Warning: These APIs are experimental and subject to change without notice
* <p>Note: deprecated! Please use {@link org.pytorch.executorch.extension.llm.LlmModule} instead.
*/
@Experimental
@Deprecated
public class LlamaModule {

public static final int MODEL_TYPE_TEXT = 1;
public static final int MODEL_TYPE_TEXT_VISION = 2;

static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("executorch");
}

private final HybridData mHybridData;
private LlmModule mModule;
private static final int DEFAULT_SEQ_LEN = 128;
private static final boolean DEFAULT_ECHO = true;

@DoNotStrip
private static native HybridData initHybrid(
int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);

/** Constructs a LLAMA Module for a model with given model path, tokenizer, temperature. */
public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
mModule = new LlmModule(modulePath, tokenizerPath, temperature);
}

/**
* Constructs a LLAMA Module for a model with given model path, tokenizer, temperature and data
* path.
*/
public LlamaModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
mModule = new LlmModule(modulePath, tokenizerPath, temperature, dataPath);
}

/** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
public LlamaModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
mModule = new LlmModule(modelType, modulePath, tokenizerPath, temperature);
}

public void resetNative() {
mHybridData.resetNative();
mModule.resetNative();
}

/**
@@ -70,7 +56,7 @@ public void resetNative() {
* @param llamaCallback callback object to receive results.
*/
public int generate(String prompt, LlamaCallback llamaCallback) {
return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
}

/**
@@ -119,16 +105,35 @@ public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, bool
* @param llamaCallback callback object to receive results.
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
*/
@DoNotStrip
public native int generate(
public int generate(
int[] image,
int width,
int height,
int channels,
String prompt,
int seqLen,
LlamaCallback llamaCallback,
boolean echo);
boolean echo) {
return mModule.generate(
image,
width,
height,
channels,
prompt,
seqLen,
new LlmCallback() {
@Override
public void onResult(String result) {
llamaCallback.onResult(result);
}

@Override
public void onStats(float tps) {
llamaCallback.onStats(tps);
}
},
echo);
}

/**
* Prefill an LLaVA Module with the given images input.
@@ -142,17 +147,9 @@
* @throws RuntimeException if the prefill failed
*/
public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
if (nativeResult[0] != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
}
return nativeResult[1];
return mModule.prefillImages(image, width, height, channels, startPos);
}

// returns a tuple of (status, updated startPos)
private native long[] prefillImagesNative(
int[] image, int width, int height, int channels, long startPos);

/**
* Prefill an LLaVA Module with the given text input.
*
@@ -165,16 +162,9 @@ private native long[] prefillImagesNative(
* @throws RuntimeException if the prefill failed
*/
public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
if (nativeResult[0] != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
}
return nativeResult[1];
return mModule.prefillPrompt(prompt, startPos, bos, eos);
}

// returns a tuple of (status, updated startPos)
private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);

/**
* Generate tokens from the given prompt, starting from the given position.
*
@@ -185,14 +175,33 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
* @param echo indicate whether to echo the input prompt or not.
* @return The error code.
*/
public native int generateFromPos(
String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo);
public int generateFromPos(
String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo) {
return mModule.generateFromPos(
prompt,
seqLen,
startPos,
new LlmCallback() {
@Override
public void onResult(String result) {
callback.onResult(result);
}

@Override
public void onStats(float tps) {
callback.onStats(tps);
}
},
echo);
}

/** Stop current generate() before it finishes. */
@DoNotStrip
public native void stop();
public void stop() {
mModule.stop();
}

/** Force loading the module. Otherwise the model is loaded during first generate(). */
@DoNotStrip
public native int load();
public int load() {
return mModule.load();
}
}
extension/android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java
@@ -0,0 +1,38 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch.extension.llm;

import com.facebook.jni.annotations.DoNotStrip;
import org.pytorch.executorch.annotations.Experimental;

/**
* Callback interface for Llama model. Users can implement this interface to receive the generated
* tokens and statistics.
*
* <p>Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
public interface LlmCallback {
/**
* Called when a new result is available from JNI. Users will keep getting onResult() invocations
* until generate() finishes.
*
* @param result Last generated token
*/
@DoNotStrip
public void onResult(String result);

/**
* Called when the statistics for the generate() is available.
*
* @param tps Tokens/second for generated tokens.
*/
@DoNotStrip
public void onStats(float tps);
}
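Since the new interface mirrors the deprecated LlamaCallback (one token callback and one tokens-per-second callback), existing implementations can usually switch by changing only the import. Below is a small, self-contained sketch of an implementation that buffers the streamed output; the class name is illustrative and not part of this PR.

import org.pytorch.executorch.extension.llm.LlmCallback;

/** Buffers streamed tokens and remembers the last reported throughput. */
public class BufferingLlmCallback implements LlmCallback {
  private final StringBuilder text = new StringBuilder();
  private float lastTps = 0f;

  @Override
  public void onResult(String token) {
    // Invoked once per generated token until generate() finishes.
    text.append(token);
  }

  @Override
  public void onStats(float tps) {
    // Tokens per second for the generated tokens.
    lastTps = tps;
  }

  public String text() {
    return text.toString();
  }

  public float lastTps() {
    return lastTps;
  }
}

An instance can then be passed to LlmModule.generate(...) or LlmModule.generateFromPos(...), both of which take an LlmCallback as shown in the diff above.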