Expose Llava as a shared library for downstream projects #3613

Merged Nov 6, 2023 · 34 commits

Commits
0209d39  wip llava python bindings compatibility (damian0815, Oct 13, 2023)
3c10d9f  add external llava API (damian0815, Oct 13, 2023)
770dc9d  add base64 in-prompt image support (damian0815, Oct 13, 2023)
8224ca5  wip refactor image loading (damian0815, Oct 14, 2023)
c693208  refactor image load out of llava init (damian0815, Oct 14, 2023)
0889117  cleanup (damian0815, Oct 14, 2023)
f83c060  further cleanup; move llava-cli into its own file and rename (damian0815, Oct 14, 2023)
e2cd07c  move base64.hpp into common/ (damian0815, Oct 14, 2023)
f8eddcf  collapse clip and llava libraries (damian0815, Oct 14, 2023)
b9f533b  move llava into its own subdir (damian0815, Oct 14, 2023)
f21af51  wip (damian0815, Oct 14, 2023)
708928c  fix bug where base64 string was not removed from the prompt (damian0815, Oct 14, 2023)
09edb7e  get libllava to output in the right place (damian0815, Oct 14, 2023)
2847ecf  expose llava methods in libllama.dylib (damian0815, Oct 14, 2023)
e3261ff  cleanup memory usage around clip_image_* (damian0815, Oct 14, 2023)
d64891b  cleanup and refactor *again* (damian0815, Oct 15, 2023)
5a91551  update headerdoc (damian0815, Oct 15, 2023)
e84003b  Move llava back to examples (monatis, Nov 2, 2023)
8037034  build with cmake, not tested (WIP) (monatis, Nov 2, 2023)
52143f7  Editorconfig (monatis, Nov 5, 2023)
c6b8844  Merge branch 'master' into llava-lib (monatis, Nov 5, 2023)
32bf7bf  Editorconfig (monatis, Nov 5, 2023)
53dca51  Build with make (monatis, Nov 5, 2023)
b927772  Build with make (monatis, Nov 5, 2023)
01f06e2  Fix cyclical depts on Windows (monatis, Nov 5, 2023)
ad97e0e  attempt to fix build on Windows (monatis, Nov 5, 2023)
71ea278  Merge branch 'master' into llava-lib (monatis, Nov 5, 2023)
1f8c866  attempt to fix build on Windows (monatis, Nov 6, 2023)
d6be69f  Upd TODOs (monatis, Nov 6, 2023)
5b8b9ef  attempt to fix build on Windows+CUDA (monatis, Nov 6, 2023)
b9bacc7  Revert changes in cmake (monatis, Nov 6, 2023)
9f03ac7  Fix according to review comments (monatis, Nov 6, 2023)
22f43fc  Support building as a shared library (monatis, Nov 6, 2023)
3548029  address review comments (cebtenzzre, Nov 6, 2023)
Changes from 1 commit: wip refactor image loading
damian0815 committed Oct 14, 2023
commit 8224ca5775b7f09f088abf2379fcac25270085d4
52 changes: 52 additions & 0 deletions examples/llava/llava-utils.h
@@ -143,3 +143,55 @@ inline const char * sample(struct llama_context * ctx_llama, gpt_params & params
     eval_id(ctx_llama, id, n_past);
     return ret.c_str();
 }
+
+static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
+static const char* IMG_BASE64_TAG_END = "\">";
+
+static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
+    begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
+    end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
+}
+
+static bool prompt_contains_image(const std::string& prompt) {
+    size_t begin, end;
+    find_image_tag_in_prompt(prompt, begin, end);
+    return (begin != std::string::npos);
+}
+
+// replaces the base64 image tag in the prompt with `replacement`
+static bool get_image_from_prompt(const std::string& prompt, clip_image_u8 * img) {
+    size_t img_base64_str_start, img_base64_str_end;
+    find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
+    if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
+        fprintf(stderr, "%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
+        return false;
+    }
+
+    auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
+    auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
+    auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count);
+
+    auto required_bytes = base64::required_encode_size(base64_str.size());
+    auto img_bytes = std::vector<unsigned char>(required_bytes);
+    auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
+    auto img_bytes_len = img_bytes_end - img_bytes.begin();
+
+    auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
+    if (!img_loaded_ok) {
+        fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
+        return false;
+    }
+
+    return true;
+}
+
+static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
+    size_t begin, end;
+    find_image_tag_in_prompt(prompt, begin, end);
+    if (begin == std::string::npos || end == std::string::npos) {
+        return prompt;
+    }
+    auto pre = prompt.substr(0, begin);
+    auto post = prompt.substr(end+1);
+    return pre + replacement + post;
+}
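Taken together, these helpers give a downstream caller the full path from a tagged prompt to a decoded image. A minimal sketch of the intended flow, assuming this header (plus clip.h and base64.hpp) is included; `split_prompt_and_image` is a hypothetical wrapper for illustration, not part of the PR:

```cpp
// Sketch of the intended flow for a prompt that may carry an inline image.
// Assumes the helpers above are in scope; `prompt` is caller-supplied.
static bool split_prompt_and_image(std::string & prompt, clip_image_u8 & img, bool & has_image) {
    has_image = prompt_contains_image(prompt);
    if (!has_image) {
        return true; // nothing to do, the prompt is text-only
    }
    if (!get_image_from_prompt(prompt, &img)) {
        return false; // tag present but the base64 payload failed to decode
    }
    // strip the tag so the language model never sees the raw base64
    prompt = remove_image_from_prompt(prompt);
    return true;
}
```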
122 changes: 39 additions & 83 deletions examples/llava/llava.cpp
@@ -37,58 +37,28 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
     return true;
 }
 
-static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
-static const char* IMG_BASE64_TAG_END = "\">";
-
-static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
-    begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
-    end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
-}
-
-static bool prompt_contains_image(const std::string& prompt) {
-    size_t begin, end;
-    find_image_tag_in_prompt(prompt, begin, end);
-    return (begin != std::string::npos);
-}
-
-// replaces the base64 image tag in the prompt with `replacement`
-static bool get_image_from_prompt(const std::string& prompt, clip_image_u8 * img) {
-    size_t img_base64_str_start, img_base64_str_end;
-    find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
-    if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
-        fprintf(stderr, "%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
+bool llava_build_img_embed(struct llava_context * ctx_llava, const clip_image_u8 * img) {
+
+    float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
+    if (!image_embd) {
+        fprintf(stderr, "Unable to allocate memory for image embeddings\n");
+        free(image_embd);
         return false;
     }
 
-    auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
-    auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
-    auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count);
-    printf("base64_str: '%s'\n", base64_str.c_str());
-
-    auto required_bytes = base64::required_encode_size(base64_str.size());
-    auto img_bytes = std::vector<unsigned char>(required_bytes);
-    auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
-    auto img_bytes_len = img_bytes_end - img_bytes.begin();
-
-    auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
-    if (!img_loaded_ok) {
-        fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
+    int n_img_embd;
+    int n_img_pos;
+    float t_img_enc_ms;
+    if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
+        fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
+        free(image_embd);
         return false;
     }
 
-    return true;
+    ctx_llava->image_embd = image_embd;
+    return true;
 }
-
-static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
-    size_t begin, end;
-    find_image_tag_in_prompt(prompt, begin, end);
-    if (begin == std::string::npos || end == std::string::npos) {
-        return prompt;
-    }
-    auto pre = prompt.substr(0, begin);
-    auto post = prompt.substr(end+1);
-    return pre + replacement + post;
-}
 
 struct llava_context * llava_init(gpt_params * params) {
Collaborator: Image loading and inference parts should be stripped off this function.

Contributor Author: done
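With image loading stripped out of llava_init, the call sequence splits into init, per-image embedding, and teardown. A hypothetical downstream sketch of that sequence as of this commit (names follow the diff; the WIP llava_build_img_embed above still references unbound ctx_clip/params internally, which later commits in the PR finish wiring up):

```cpp
// Hypothetical downstream usage of the API as of this commit: initialize once,
// then build the image embedding as a separate step. Assumes the llava example
// headers from this PR are included; "image.jpg" is a placeholder path.
gpt_params params; // assume populated from argv elsewhere
struct llava_context * ctx_llava = llava_init(&params);
if (ctx_llava != NULL) {
    clip_image_u8 img;
    if (clip_image_load_from_file("image.jpg", &img) &&
        llava_build_img_embed(ctx_llava, &img)) {
        // ctx_llava->image_embd now holds the CLIP image embedding
    }
    llava_free(ctx_llava); // also releases ctx_clip as of this commit
}
```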


@@ -102,46 +72,6 @@ struct llava_context * llava_init(gpt_params * params) {
 
     auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
 
-    // load and preprocess the image
-    clip_image_u8 img;
-
-    if (prompt_contains_image(prompt)) {
-        if (img_path) {
-            printf("using base64 encoded image instead of command line image path\n");
-        }
-        if (!get_image_from_prompt(prompt, &img)) {
-            fprintf(stderr, "%s: can't load image from prompt\n", __func__);
-            clip_free(ctx_clip);
-            return NULL;
-        }
-        prompt = remove_image_from_prompt(prompt);
-    } else {
-        if (!clip_image_load_from_file(img_path, &img)) {
-            fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
-            clip_free(ctx_clip);
-            return NULL;
-        }
-    }
-
-    float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
-    if (!image_embd) {
-        fprintf(stderr, "Unable to allocate memory for image embeddings\n");
-        return NULL;
-    }
-
-    int n_img_embd;
-    int n_img_pos;
-    float t_img_enc_ms;
-    if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
-        fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
-        clip_free(ctx_clip);
-        return NULL;
-    }
-
-    // we get the embeddings, free up the memory required for CLIP
-    clip_free(ctx_clip);
-    ctx_clip = NULL;
-
     llama_backend_init(params->numa);
 
     llama_model_params model_params = llama_model_default_params();
@@ -194,6 +124,11 @@ struct llava_context * llava_init(gpt_params * params) {
 }
 
 void llava_free(struct llava_context * ctx_llava) {
+    if (ctx_llava->ctx_clip) {
+        clip_free(ctx_llava->ctx_clip);
+        ctx_llava->ctx_clip = NULL;
+    }
+
     llama_free(ctx_llava->ctx_llama);
     llama_free_model(ctx_llava->model);
     llama_backend_free();
@@ -249,6 +184,27 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // load and preprocess the image
+    clip_image_u8 img;
+    if (prompt_contains_image(prompt)) {
+        if (img_path) {
+            printf("using base64 encoded image instead of command line image path\n");
+        }
+        if (!get_image_from_prompt(prompt, &img)) {
+            fprintf(stderr, "%s: can't load image from prompt\n", __func__);
+            clip_free(ctx_clip);
+            return NULL;
+        }
+        prompt = remove_image_from_prompt(prompt);
+    } else {
+        if (!clip_image_load_from_file(img_path, &img)) {
+            fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
+            clip_free(ctx_clip);
+            return NULL;
+        }
+    }
+    llava_build_img_embed(ctx_llava, &img);
+
     // process the prompt
     // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
     llava_process_prompt(ctx_llava, &params, params.prompt.c_str());
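The base64 tag from llava-utils.h is the other route into that chat format: embed the image directly in the textual prompt instead of passing a file path. An illustrative sketch, where `make_llava_prompt` is a hypothetical helper and `jpeg_base64` is assumed to hold an already-encoded JPEG:

```cpp
#include <string>

// Sketch: building a prompt that carries its image inline, using the tag
// format introduced in this PR (IMG_BASE64_TAG_BEGIN/END in llava-utils.h).
// `jpeg_base64` is a placeholder for base64-encoded JPEG bytes produced elsewhere.
std::string make_llava_prompt(const std::string & jpeg_base64, const std::string & question) {
    // prompt_contains_image() will return true for the result; llava-cli
    // decodes the payload, strips the tag, and evaluates the image embedding
    // between "USER: " and the question, per the chat format documented above.
    return "<img src=\"data:image/jpeg;base64," + jpeg_base64 + "\"> " + question;
}
```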