
Feat: add multimodal - poc #141


Merged
merged 29 commits on May 19, 2025
Commits (29)
668f1eb
feat: add multimodal
a-ghorbani May 11, 2025
45801aa
feat: include image processing in doCompletion
a-ghorbani May 12, 2025
2b7372c
Update cpp/rn-llama.cpp
a-ghorbani May 12, 2025
aae6515
Update cpp/rn-llama.cpp
a-ghorbani May 12, 2025
3fa3502
fix: calculate token counts and embd correctly for multimodal path
a-ghorbani May 14, 2025
2214a97
chore: remove debug logger
a-ghorbani May 14, 2025
e22cd25
fix(ios): correct file paths in iOS CMakeLists.txt
a-ghorbani May 14, 2025
b513b00
feat: add mtmd header files and update include paths in build scripts
a-ghorbani May 15, 2025
0a27900
feat: expose mmproj_use_gpu
a-ghorbani May 15, 2025
620a435
fix: remove use_gpu from initMultimodal api
a-ghorbani May 15, 2025
87ed99a
fix(ts): correct image path
jhen0409 May 17, 2025
90437e7
feat: accept multiple image paths
jhen0409 May 17, 2025
09ca6c5
fix(cpp): embd cache
jhen0409 May 17, 2025
1fb0bda
feat(android): refactor image_paths param
jhen0409 May 17, 2025
788b496
fix(cpp): correct log
jhen0409 May 17, 2025
91a857c
feat(cpp): avoid process chunk for embed
jhen0409 May 17, 2025
664c0b1
feat: support base64 image
jhen0409 May 17, 2025
f2436a5
fix(ios): revert lock change
jhen0409 May 17, 2025
c4038ef
fix(ts): codegen
jhen0409 May 17, 2025
a2e1966
feat(ts): refactor getFormattedChat
jhen0409 May 17, 2025
93aba5d
chore(example): cleanup
jhen0409 May 17, 2025
bd700a0
fix(example): limit image selection to supported formats only
a-ghorbani May 18, 2025
70dfb3c
chore: clean up some logs
a-ghorbani May 18, 2025
3ea0cc7
fix: add check to prevent context overflow when processing tokens in …
a-ghorbani May 18, 2025
7978f08
feat: remove mmproj param in initContext
jhen0409 May 19, 2025
6efc96c
feat: move mmproj_use_gpu to initMultimodal
jhen0409 May 19, 2025
52068b5
fix: remove mmproj_use_gpu in NativeContextParams
jhen0409 May 19, 2025
4c34cb7
feat: add releaseMultimodal method
jhen0409 May 19, 2025
be1bfe3
feat(cpp, ios): avoid mtmd/clip in framework headers
jhen0409 May 19, 2025
9 changes: 9 additions & 0 deletions android/src/main/CMakeLists.txt
@@ -8,6 +8,7 @@ set(RNLLAMA_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../cpp)
include_directories(
${RNLLAMA_LIB_DIR}
${RNLLAMA_LIB_DIR}/ggml-cpu
${RNLLAMA_LIB_DIR}/tools/mtmd
)

set(
@@ -34,6 +35,14 @@ set(
${RNLLAMA_LIB_DIR}/gguf.cpp
${RNLLAMA_LIB_DIR}/log.cpp
${RNLLAMA_LIB_DIR}/llama-impl.cpp
# Multimodal support
${RNLLAMA_LIB_DIR}/tools/mtmd/mtmd.cpp
${RNLLAMA_LIB_DIR}/tools/mtmd/mtmd.h
${RNLLAMA_LIB_DIR}/tools/mtmd/clip.cpp
${RNLLAMA_LIB_DIR}/tools/mtmd/clip.h
${RNLLAMA_LIB_DIR}/tools/mtmd/clip-impl.h
${RNLLAMA_LIB_DIR}/tools/mtmd/mtmd-helper.cpp
${RNLLAMA_LIB_DIR}/tools/mtmd/stb_image.h
${RNLLAMA_LIB_DIR}/llama-grammar.cpp
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
27 changes: 27 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -309,6 +309,8 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("top_n_sigma") ? (float) params.getDouble("top_n_sigma") : -1.0f,
// String[] dry_sequence_breakers, when undef, we use the default definition from common.h
params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"},
// String[] image_paths
params.hasKey("image_paths") ? params.getArray("image_paths").toArrayList().toArray(new String[0]) : new String[0],
// PartialCompletionCallback partial_completion_callback
new PartialCompletionCallback(
this,
@@ -379,6 +381,27 @@ public WritableArray getLoadedLoraAdapters() {
return getLoadedLoraAdapters(this.context);
}

public boolean initMultimodal(ReadableMap params) {
String mmprojPath = params.getString("path");
boolean mmprojUseGpu = params.hasKey("use_gpu") ? params.getBoolean("use_gpu") : true;
if (mmprojPath == null || mmprojPath.isEmpty()) {
throw new IllegalArgumentException("mmproj_path is empty");
}
File file = new File(mmprojPath);
if (!file.exists()) {
throw new IllegalArgumentException("mmproj file does not exist: " + mmprojPath);
}
return initMultimodal(this.context, mmprojPath, mmprojUseGpu);
}

public boolean isMultimodalEnabled() {
return isMultimodalEnabled(this.context);
}

public void releaseMultimodal() {
releaseMultimodal(this.context);
}

public void release() {
freeContext(context);
}
@@ -497,6 +520,8 @@ protected static native long initContext(
boolean ctx_shift,
LoadProgressCallback load_progress_callback
);
protected static native boolean initMultimodal(long contextPtr, String mmproj_path, boolean mmproj_use_gpu);
protected static native boolean isMultimodalEnabled(long contextPtr);
protected static native void interruptLoad(long contextPtr);
protected static native WritableMap loadModelDetails(
long contextPtr
@@ -560,6 +585,7 @@ protected static native WritableMap doCompletion(
int dry_penalty_last_n,
float top_n_sigma,
String[] dry_sequence_breakers,
String[] image_paths,
PartialCompletionCallback partial_completion_callback
);
protected static native void stopCompletion(long contextPtr);
@@ -579,4 +605,5 @@ protected static native WritableMap embedding(
protected static native void freeContext(long contextPtr);
protected static native void setupLog(NativeLogCallback logCallback);
protected static native void unsetLog();
protected static native void releaseMultimodal(long contextPtr);
}
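
Taken together, the Java changes above expose three new context-level calls: initMultimodal, isMultimodalEnabled, and releaseMultimodal. Below is a minimal TypeScript sketch of how they might be driven from the JS side of the library; the "path" and "use_gpu" keys mirror the ReadableMap keys parsed in LlamaContext.initMultimodal above, while the initLlama helper and the context wrapper's method names are assumptions for illustration.

// Hypothetical usage sketch, assuming the TS context wrapper exposes
// the new native methods under the same names.
import { initLlama } from 'llama.rn'

async function setupMultimodal(modelPath: string, mmprojPath: string) {
  const context = await initLlama({ model: modelPath })

  // 'path' and 'use_gpu' match the keys read by LlamaContext.initMultimodal;
  // use_gpu defaults to true on the Java side when omitted.
  const ok = await context.initMultimodal({ path: mmprojPath, use_gpu: true })
  if (!ok || !(await context.isMultimodalEnabled())) {
    throw new Error('failed to enable multimodal support')
  }
  return context
}

// Later, to free the projector without releasing the whole context:
// await context.releaseMultimodal()
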
100 changes: 100 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
@@ -548,6 +548,106 @@ protected void onPostExecute(ReadableArray result) {
tasks.put(task, "getLoadedLoraAdapters-" + contextId);
}

public void initMultimodal(double id, final ReadableMap params, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Boolean>() {
private Exception exception;

@Override
protected Boolean doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context not found");
}
if (context.isPredicting()) {
throw new Exception("Context is busy");
}
return context.initMultimodal(params);
} catch (Exception e) {
exception = e;
}
return false;
}

@Override
protected void onPostExecute(Boolean result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
tasks.put(task, "initMultimodal-" + contextId);
}

public void isMultimodalEnabled(double id, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Boolean>() {
private Exception exception;

@Override
protected Boolean doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context not found");
}
return context.isMultimodalEnabled();
} catch (Exception e) {
exception = e;
}
return false;
}

@Override
protected void onPostExecute(Boolean result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
tasks.put(task, "isMultimodalEnabled" + contextId);
}

public void releaseMultimodal(double id, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Void>() {
private Exception exception;

@Override
protected Void doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context not found");
}
context.releaseMultimodal();
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(Void result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(null);
tasks.remove(this);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
tasks.put(task, "releaseMultimodal" + id);
}

public void releaseContext(double id, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Void>() {
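
Each wrapper above rejects the promise when the target context is missing; initMultimodal additionally rejects while a prediction is running or when the mmproj path is empty or nonexistent. A hedged sketch of handling those rejections from TypeScript:

try {
  await context.initMultimodal({ path: mmprojPath })
} catch (e) {
  // Possible rejection reasons surfaced by the Java layer above:
  // 'Context not found', 'Context is busy', 'mmproj_path is empty',
  // 'mmproj file does not exist: <path>'.
  console.warn('initMultimodal failed:', e)
}
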
80 changes: 76 additions & 4 deletions android/src/main/jni.cpp
@@ -685,6 +685,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
jint dry_penalty_last_n,
jfloat top_n_sigma,
jobjectArray dry_sequence_breakers,
jobjectArray image_paths,
jobject partial_completion_callback
) {
UNUSED(thiz);
@@ -694,8 +695,32 @@

//llama_reset_timings(llama->ctx);

auto prompt_chars = env->GetStringUTFChars(prompt, nullptr);
const char *prompt_chars = env->GetStringUTFChars(prompt, nullptr);

// Set the prompt parameter
llama->params.prompt = prompt_chars;

// Process image paths if provided
std::vector<std::string> image_paths_vector;

jint image_paths_size = env->GetArrayLength(image_paths);
if (image_paths_size > 0) {
// Check if multimodal is enabled
if (!llama->isMultimodalEnabled()) {
auto result = createWriteableMap(env);
putString(env, result, "error", "Multimodal support not enabled. Call initMultimodal first.");
env->ReleaseStringUTFChars(prompt, prompt_chars);
return reinterpret_cast<jobject>(result);
}

for (jint i = 0; i < image_paths_size; i++) {
jstring image_path = (jstring) env->GetObjectArrayElement(image_paths, i);
const char *image_path_chars = env->GetStringUTFChars(image_path, nullptr);
image_paths_vector.push_back(image_path_chars);
env->ReleaseStringUTFChars(image_path, image_path_chars);
}
}

llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;

int max_threads = std::thread::hardware_concurrency();
@@ -853,7 +878,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
return reinterpret_cast<jobject>(result);
}
llama->beginCompletion();
llama->loadPrompt();
llama->loadPrompt(image_paths_vector);

if (llama->context_full) {
auto result = createWriteableMap(env);
@@ -922,7 +947,12 @@
}

env->ReleaseStringUTFChars(grammar, grammar_chars);
env->ReleaseStringUTFChars(prompt, prompt_chars);

// Release prompt_chars if it's still allocated
if (prompt_chars != nullptr) {
env->ReleaseStringUTFChars(prompt, prompt_chars);
}

llama_perf_context_print(llama->ctx);
llama->is_predicting = false;

@@ -1098,7 +1128,7 @@ Java_com_rnllama_LlamaContext_embedding(
}

llama->beginCompletion();
llama->loadPrompt();
llama->loadPrompt({});
llama->doCompletion();

std::vector<float> embedding = llama->getEmbedding(embdParams);
@@ -1267,4 +1297,46 @@ Java_com_rnllama_LlamaContext_unsetLog(JNIEnv *env, jobject thiz) {
llama_log_set(rnllama_log_callback_default, NULL);
}

JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_initMultimodal(
JNIEnv *env,
jobject thiz,
jlong context_ptr,
jstring mmproj_path,
jboolean mmproj_use_gpu
) {
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];

const char *mmproj_path_chars = env->GetStringUTFChars(mmproj_path, nullptr);
bool result = llama->initMultimodal(mmproj_path_chars, mmproj_use_gpu);
env->ReleaseStringUTFChars(mmproj_path, mmproj_path_chars);

return result;
}

JNIEXPORT jboolean JNICALL
Java_com_rnllama_LlamaContext_isMultimodalEnabled(
JNIEnv *env,
jobject thiz,
jlong context_ptr
) {
UNUSED(env);
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];
return llama->isMultimodalEnabled();
}

JNIEXPORT void JNICALL
Java_com_rnllama_LlamaContext_releaseMultimodal(
JNIEnv *env,
jobject thiz,
jlong context_ptr
) {
UNUSED(env);
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];
llama->releaseMultimodal();
}

} // extern "C"
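
On the JNI path above, a non-empty image_paths array is rejected unless initMultimodal has already run, and the collected paths are forwarded to loadPrompt. A hedged TypeScript sketch of a completion call that exercises this path; the exact completion parameter shape is an assumption, and base64 data URIs are assumed to be accepted per commit 664c0b1 ("feat: support base64 image").

// Hypothetical completion call; image_paths matches the new parameter
// threaded through doCompletion in the Java and JNI layers above.
const result = await context.completion({
  prompt: 'Describe this image.',
  image_paths: [
    '/absolute/path/to/photo.jpg',
    // 'data:image/jpeg;base64,...' (assumed to work per commit 664c0b1)
  ],
  n_predict: 128,
})

// Without a prior initMultimodal call, the native layer returns:
// { error: 'Multimodal support not enabled. Call initMultimodal first.' }
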
15 changes: 15 additions & 0 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
@@ -52,6 +52,21 @@ public void initContext(double id, final ReadableMap params, final Promise promi
rnllama.initContext(id, params, promise);
}

@ReactMethod
public void initMultimodal(double id, final ReadableMap params, final Promise promise) {
rnllama.initMultimodal(id, params, promise);
}

@ReactMethod
public void isMultimodalEnabled(double id, final Promise promise) {
rnllama.isMultimodalEnabled(id, promise);
}

@ReactMethod
public void releaseMultimodal(double id, final Promise promise) {
rnllama.releaseMultimodal(id, promise);
}

@ReactMethod
public void getFormattedChat(double id, String messages, String chatTemplate, ReadableMap params, Promise promise) {
rnllama.getFormattedChat(id, messages, chatTemplate, params, promise);
15 changes: 15 additions & 0 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
@@ -53,6 +53,21 @@ public void initContext(double id, final ReadableMap params, final Promise promi
rnllama.initContext(id, params, promise);
}

@ReactMethod
public void initMultimodal(double id, final ReadableMap params, final Promise promise) {
rnllama.initMultimodal(id, params, promise);
}

@ReactMethod
public void isMultimodalEnabled(double id, final Promise promise) {
rnllama.isMultimodalEnabled(id, promise);
}

@ReactMethod
public void releaseMultimodal(double id, final Promise promise) {
rnllama.releaseMultimodal(id, promise);
}

@ReactMethod
public void getFormattedChat(double id, String messages, String chatTemplate, ReadableMap params, Promise promise) {
rnllama.getFormattedChat(id, messages, chatTemplate, params, promise);
4 changes: 3 additions & 1 deletion cpp/chat.cpp
@@ -115,7 +115,9 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
msgs.push_back(msg);
}
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
// @ngxson : disable otherwise it's bloating the API response
// printf("%s\n", std::string("; messages = ") + messages.dump(2));
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()));
}

return msgs;
1 change: 0 additions & 1 deletion cpp/common.cpp
@@ -1109,7 +1109,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base;