Skip to content

Commit 8de3bf5

Browse files
ngxson authored and ggerganov committed
llama-quant: add support for mmproj (ggml-org#16592)
* llama-quant: add support for mmproj
* Update src/llama.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* check prefix instead
* small fix

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 03780eb commit 8de3bf5

File tree

5 files changed

+19
-2
lines changed

5 files changed

+19
-2
lines changed

src/llama-arch.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <map>
66

77
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
8+
{ LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize
89
{ LLM_ARCH_LLAMA, "llama" },
910
{ LLM_ARCH_LLAMA4, "llama4" },
1011
{ LLM_ARCH_DECI, "deci" },
@@ -275,6 +276,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
275276
};
276277

277278
static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
279+
{
280+
LLM_ARCH_CLIP,
281+
{},
282+
},
278283
{
279284
LLM_ARCH_LLAMA,
280285
{

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//
1010

1111
enum llm_arch {
12+
LLM_ARCH_CLIP,
1213
LLM_ARCH_LLAMA,
1314
LLM_ARCH_LLAMA4,
1415
LLM_ARCH_DECI,

src/llama-model.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
478478
ml.get_key(LLM_KV_GENERAL_NAME, name, false);
479479

480480
// everything past this point is not vocab-related
481-
if (hparams.vocab_only) {
481+
// for CLIP models, we only need to load tensors, no hparams
482+
if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) {
482483
return;
483484
}
484485

@@ -20013,6 +20014,7 @@ int32_t llama_n_head(const llama_model * model) {
2001320014
llama_rope_type llama_model_rope_type(const llama_model * model) {
2001420015
switch (model->arch) {
2001520016
// these models do not use RoPE
20017+
case LLM_ARCH_CLIP:
2001620018
case LLM_ARCH_GPT2:
2001720019
case LLM_ARCH_GPTJ:
2001820020
case LLM_ARCH_MPT:

src/llama-quant.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
701701
});
702702
}
703703

704+
bool is_clip_model = false;
704705
for (const auto * it : tensors) {
705706
const struct ggml_tensor * tensor = it->tensor;
706707

@@ -714,12 +715,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
714715
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
715716
qs.has_output = true;
716717
}
718+
719+
is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
717720
}
718721

719722
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
720723

721724
// sanity checks for models that have attention layers
722-
if (qs.n_attention_wv != 0)
725+
if (qs.n_attention_wv != 0 && !is_clip_model)
723726
{
724727
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
725728
// attention layers have a non-zero number of kv heads
@@ -881,6 +884,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
881884
// do not quantize relative position bias (T5)
882885
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
883886

887+
// do not quantize specific multimodal tensors
888+
quantize &= name.find(".position_embd.") == std::string::npos;
889+
884890
ggml_type new_type;
885891
void * new_data;
886892
size_t new_size;

src/llama.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
124124
} catch(const std::exception & e) {
125125
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
126126
}
127+
if (model.arch == LLM_ARCH_CLIP) {
128+
throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
129+
}
127130
try {
128131
model.load_vocab(ml);
129132
} catch(const std::exception & e) {

0 commit comments

Comments (0)