Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -493,15 +493,15 @@ private native int appendRawAudioInput(
*/
@Experimental
public long prefillPrompt(String prompt) {
  // Delegate to the native layer, which either prefills the text-only LLM's
  // KV cache directly or queues the text for the multimodal runner.
  int nativeResult = prefillTextInput(prompt);
  if (nativeResult != 0) {
    // Non-zero is a native error code; surface it to the caller.
    throw new RuntimeException("Prefill failed with error code: " + nativeResult);
  }
  // Reserved return value; currently always 0 on success.
  return 0;
}

// Returns a native status code: 0 on success, non-zero on failure.
private native int appendTextInput(String prompt);
private native int prefillTextInput(String prompt);

/**
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
Expand Down
10 changes: 8 additions & 2 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,12 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {

// Returns status_code
// Contract is valid within an AAR (JNI + corresponding Java code)
jint prefill_text_input(facebook::jni::alias_ref<jstring> prompt) {
  // Text-only LLM path: prefill the runner's KV cache immediately so a
  // later generate() continues from this state.
  if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM && runner_) {
    executorch::extension::llm::GenerationConfig config;
    auto err = runner_->prefill(prompt->toStdString(), config);
    return static_cast<jint>(err);
  }
  // Multimodal (or not-yet-loaded) path: queue the text; it is consumed by
  // the next generate() call. Always reports success.
  prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
  return 0;
}
Expand Down Expand Up @@ -391,6 +396,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
if (multi_modal_runner_ != nullptr) {
multi_modal_runner_->reset();
}
prefill_inputs_.clear();
}

jint load() {
Expand Down Expand Up @@ -438,7 +444,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
makeNativeMethod(
"appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input),
makeNativeMethod(
"appendTextInput", ExecuTorchLlmJni::append_text_input),
"prefillTextInput", ExecuTorchLlmJni::prefill_text_input),
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
});
}
Expand Down
19 changes: 19 additions & 0 deletions extension/llm/runner/irunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,25 @@ class ET_EXPERIMENTAL IRunner {
*/
virtual void stop() = 0;

/**
 * Populate the KV cache with the given prompt without emitting any tokens.
 *
 * After a successful prefill, a subsequent generate() call continues from
 * the cached state — useful for restoring a previous chat transcript.
 *
 * @param prompt Text to feed into the model.
 * @param config Generation settings; the encoding-related fields
 *        (num_bos, num_eos) are the ones consulted.
 * @return Error::Ok on success; the default implementation returns
 *         Error::NotSupported for runners without standalone prefill.
 */
virtual runtime::Error prefill(
    [[maybe_unused]] const std::string& prompt,
    [[maybe_unused]] const GenerationConfig& config) {
  // Opt-in capability: subclasses override this to provide a real prefill.
  return runtime::Error::NotSupported;
}

/**
* Force remove prefilled tokens and reset KV cache start position
*
Expand Down
2 changes: 1 addition & 1 deletion extension/llm/runner/text_llm_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
*/
::executorch::runtime::Error prefill(
const std::string& prompt,
const GenerationConfig& config);
const GenerationConfig& config) override;

/**
* @brief Warms up the model with a sample prompt
Expand Down
Loading