Skip to content

Commit 92ecdcc

Browse files
authored
mtmd : add vision support for llama 4 (#13282)
* wip llama 4 conversion * rm redundant __init__ * fix conversion * fix conversion * test impl * try this * reshape patch_embeddings_0 * fix view * rm ffn_post_norm * cgraph ok * f32 for pos embd * add image marker tokens * Llama4UnfoldConvolution * correct pixel shuffle * fix merge conflicts * correct * add debug_graph * logits matched, but it still preceives the image incorrectly * fix style * add image_grid_pinpoints * handle llama 4 preprocessing * rm load_image_size * rm unused line * fix * small fix 2 * add test & docs * fix llava-1.6 test * test: add notion of huge models * add comment * add warn about degraded quality
1 parent f71f40a commit 92ecdcc

File tree

9 files changed

+424
-82
lines changed

9 files changed

+424
-82
lines changed

convert_hf_to_gguf.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def prepare_tensors(self):
308308
gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
309309
gguf.MODEL_TENSOR.POSNET_NORM1,
310310
gguf.MODEL_TENSOR.POSNET_NORM2,
311+
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
311312
)
312313
)
313314
or not new_name.endswith(".weight")
@@ -2092,6 +2093,26 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
20922093
return super().modify_tensors(data_torch, name, bid)
20932094

20942095

2096+
@ModelBase.register("Llama4ForConditionalGeneration")
2097+
class Llama4VisionModel(VisionModel):
2098+
def set_gguf_parameters(self):
2099+
super().set_gguf_parameters()
2100+
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
2101+
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
2102+
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"]))
2103+
assert self.hparams["hidden_act"] == "gelu"
2104+
self.gguf_writer.add_vision_use_gelu(True)
2105+
2106+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
2107+
del bid # unused
2108+
if "multi_modal_projector" in name or "vision_model" in name:
2109+
# process vision tensors
2110+
if "positional_embedding_vlm" in name and ".weight" not in name:
2111+
name += ".weight"
2112+
return [(self.map_tensor_name(name), data_torch)]
2113+
return []
2114+
2115+
20952116
@ModelBase.register("Mistral3ForConditionalGeneration")
20962117
class Mistral3Model(LlamaModel):
20972118
model_arch = gguf.MODEL_ARCH.LLAMA

docs/multimodal.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,7 @@ NOTE: some models may require large context window, for example: `-c 8192`
7474
(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF
7575
(tool_name) -hf ggml-org/InternVL3-8B-Instruct-GGUF
7676
(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF
77+
78+
# Llama 4 Scout
79+
(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
7780
```

gguf-py/gguf/constants.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -482,14 +482,15 @@ class MODEL_TENSOR(IntEnum):
482482
V_ENC_EMBD_CLS = auto()
483483
V_ENC_EMBD_PATCH = auto()
484484
V_ENC_EMBD_POS = auto()
485+
V_ENC_INPUT_NORM = auto()
485486
V_ENC_ATTN_Q = auto()
486487
V_ENC_ATTN_Q_NORM = auto()
487488
V_ENC_ATTN_K = auto()
488489
V_ENC_ATTN_K_NORM = auto()
489490
V_ENC_ATTN_V = auto()
490-
V_ENC_INPUT_NORM = auto()
491-
V_ENC_OUTPUT = auto()
492-
V_ENC_OUTPUT_NORM = auto()
491+
V_ENC_ATTN_O = auto()
492+
V_ENC_ATTN_O_NORM = auto()
493+
V_ENC_POST_ATTN_NORM = auto()
493494
V_ENC_FFN_UP = auto()
494495
V_ENC_FFN_GATE = auto()
495496
V_ENC_FFN_DOWN = auto()
@@ -749,8 +750,9 @@ class MODEL_TENSOR(IntEnum):
749750
MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm",
750751
MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v",
751752
MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1",
752-
MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out",
753-
MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.blk.{bid}.ln2",
753+
MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out",
754+
MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm",
755+
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2",
754756
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
755757
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
756758
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
@@ -785,14 +787,15 @@ class MODEL_TENSOR(IntEnum):
785787
MODEL_TENSOR.V_ENC_EMBD_CLS,
786788
MODEL_TENSOR.V_ENC_EMBD_PATCH,
787789
MODEL_TENSOR.V_ENC_EMBD_POS,
790+
MODEL_TENSOR.V_ENC_INPUT_NORM,
788791
MODEL_TENSOR.V_ENC_ATTN_Q,
789792
MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
790793
MODEL_TENSOR.V_ENC_ATTN_K,
791794
MODEL_TENSOR.V_ENC_ATTN_K_NORM,
792795
MODEL_TENSOR.V_ENC_ATTN_V,
793-
MODEL_TENSOR.V_ENC_INPUT_NORM,
794-
MODEL_TENSOR.V_ENC_OUTPUT,
795-
MODEL_TENSOR.V_ENC_OUTPUT_NORM,
796+
MODEL_TENSOR.V_ENC_ATTN_O,
797+
MODEL_TENSOR.V_ENC_ATTN_O_NORM,
798+
MODEL_TENSOR.V_ENC_POST_ATTN_NORM,
796799
MODEL_TENSOR.V_ENC_FFN_UP,
797800
MODEL_TENSOR.V_ENC_FFN_GATE,
798801
MODEL_TENSOR.V_ENC_FFN_DOWN,
@@ -2180,6 +2183,7 @@ class VisionProjectorType:
21802183
GEMMA3 = "gemma3"
21812184
IDEFICS3 = "idefics3"
21822185
PIXTRAL = "pixtral"
2186+
LLAMA4 = "llama4"
21832187
QWEN2VL = "qwen2vl_merger"
21842188
QWEN25VL = "qwen2.5vl_merger"
21852189
INTERNVL = "internvl"

gguf-py/gguf/tensor_mapping.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -902,10 +902,12 @@ class TensorNameMap:
902902

903903
MODEL_TENSOR.V_MMPROJ_FC: (
904904
"model.connector.modality_projection.proj", # SmolVLM
905+
"multi_modal_projector.linear_1", # llama 4
905906
),
906907

907908
MODEL_TENSOR.V_MMPROJ_MLP: (
908909
"model.mm_projector.mlp.mlp.{bid}",
910+
"vision_model.vision_adapter.mlp.fc{bid}", # llama 4
909911
"mlp1.{bid}", # InternVL
910912
),
911913

@@ -915,26 +917,30 @@ class TensorNameMap:
915917

916918
MODEL_TENSOR.V_ENC_EMBD_CLS: (
917919
"vision_tower.vision_model.embeddings.class_embedding",
920+
"vision_model.class_embedding", # llama 4
918921
),
919922

920923
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
921924
"vision_tower.vision_model.embeddings.patch_embedding",
922925
"vpm.embeddings.patch_embedding",
923926
"model.vision_model.embeddings.patch_embedding", # SmolVLM
924927
"vision_tower.patch_conv", # pixtral
928+
"vision_model.patch_embedding.linear", # llama 4
925929
"visual.patch_embed.proj", # qwen2vl
926930
),
927931

928932
MODEL_TENSOR.V_ENC_EMBD_POS: (
929933
"vision_tower.vision_model.embeddings.position_embedding",
930934
"vpm.embeddings.position_embedding",
931935
"model.vision_model.embeddings.position_embedding", # SmolVLM
936+
"vision_model.positional_embedding_vlm", # llama 4
932937
),
933938

934939
MODEL_TENSOR.V_ENC_ATTN_Q: (
935940
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
936941
"vpm.encoder.layers.{bid}.self_attn.q_proj",
937942
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
943+
"vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
938944
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
939945
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
940946
),
@@ -947,6 +953,7 @@ class TensorNameMap:
947953
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
948954
"vpm.encoder.layers.{bid}.self_attn.k_proj",
949955
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
956+
"vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
950957
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
951958
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
952959
),
@@ -959,6 +966,7 @@ class TensorNameMap:
959966
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
960967
"vpm.encoder.layers.{bid}.self_attn.v_proj",
961968
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
969+
"vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
962970
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
963971
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
964972
),
@@ -969,23 +977,26 @@ class TensorNameMap:
969977
"vpm.encoder.layers.{bid}.layer_norm1",
970978
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
971979
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
980+
"vision_model.model.layers.{bid}.input_layernorm", # llama4
972981
"visual.blocks.{bid}.norm1", # qwen2vl
973982
),
974983

975-
MODEL_TENSOR.V_ENC_OUTPUT: (
984+
MODEL_TENSOR.V_ENC_ATTN_O: (
976985
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
977986
"vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
978987
"vpm.encoder.layers.{bid}.self_attn.out_proj",
979988
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
989+
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
980990
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
981991
"visual.blocks.{bid}.attn.proj", # qwen2vl
982992
),
983993

984-
MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
994+
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
985995
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
986996
"vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
987997
"vpm.encoder.layers.{bid}.layer_norm2",
988998
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
999+
"vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
9891000
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
9901001
"visual.blocks.{bid}.norm2", # qwen2vl
9911002
),
@@ -995,6 +1006,7 @@ class TensorNameMap:
9951006
"vpm.encoder.layers.{bid}.mlp.fc1",
9961007
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
9971008
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
1009+
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
9981010
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
9991011
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
10001012
),
@@ -1009,6 +1021,7 @@ class TensorNameMap:
10091021
"vpm.encoder.layers.{bid}.mlp.fc2",
10101022
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
10111023
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
1024+
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
10121025
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
10131026
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
10141027
),
@@ -1024,11 +1037,13 @@ class TensorNameMap:
10241037
MODEL_TENSOR.V_PRE_NORM: (
10251038
"vision_tower.vision_model.pre_layrnorm",
10261039
"vision_tower.ln_pre", # pixtral
1040+
"vision_model.layernorm_pre", # llama4
10271041
),
10281042

10291043
MODEL_TENSOR.V_POST_NORM: (
10301044
"vision_tower.vision_model.post_layernorm",
10311045
"model.vision_model.post_layernorm", # SmolVLM
1046+
"vision_model.layernorm_post", # llama4
10321047
"visual.merger.ln_q", # qwen2vl
10331048
),
10341049

tools/mtmd/clip-impl.h

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <climits>
66
#include <cstdarg>
7+
#include <cinttypes>
78
#include <string>
89
#include <map>
910
#include <sstream>
@@ -44,7 +45,7 @@
4445
// tensor name constants
4546
//
4647

47-
#define TN_POS_EMBD "%s.position_embd.weight"
48+
#define TN_POS_EMBD "v.position_embd.weight"
4849
#define TN_CLASS_EMBD "v.class_embd"
4950
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
5051
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
@@ -110,6 +111,7 @@ enum projector_type {
110111
PROJECTOR_TYPE_PIXTRAL,
111112
PROJECTOR_TYPE_QWEN25VL,
112113
PROJECTOR_TYPE_INTERNVL,
114+
PROJECTOR_TYPE_LLAMA4,
113115
PROJECTOR_TYPE_UNKNOWN,
114116
};
115117

@@ -125,6 +127,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
125127
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
126128
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
127129
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
130+
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
128131
};
129132

130133
static projector_type clip_projector_type_from_string(const std::string & str) {
@@ -240,6 +243,11 @@ struct clip_image_u8_batch {
240243
struct clip_image_f32_batch {
241244
std::vector<clip_image_f32_ptr> entries;
242245

246+
// for llava-uhd style models, we need to know the grid size
247+
// note: entries.size() == grid_x * grid_y + 1 (one overview image)
248+
int grid_x = 0;
249+
int grid_y = 0;
250+
243251
clip_image_f32_batch clone() const {
244252
clip_image_f32_batch new_batch;
245253
new_batch.entries.reserve(entries.size());
@@ -358,6 +366,70 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
358366
}
359367
}
360368

369+
//
370+
// debugging
371+
//
372+
373+
static void print_tensor_shape(ggml_tensor * t) {
374+
printf("%s.shape = [", t->name);
375+
for (int i = 0; i < ggml_n_dims(t); ++i) {
376+
printf("%" PRId64, t->ne[i]);
377+
if (i < ggml_n_dims(t) - 1) {
378+
printf(", ");
379+
}
380+
}
381+
printf("]\n");
382+
}
383+
384+
static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
385+
ggml_type type = t->type;
386+
int64_t * ne = t->ne;
387+
size_t * nb = t->nb;
388+
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
389+
printf("%s.data: [\n", t->name);
390+
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
391+
if (i2 == n && ne[2] > 2*n) {
392+
printf(" ..., \n");
393+
i2 = ne[2] - n;
394+
}
395+
printf(" [\n");
396+
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
397+
if (i1 == n && ne[1] > 2*n) {
398+
printf(" ..., \n");
399+
i1 = ne[1] - n;
400+
}
401+
printf(" [");
402+
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
403+
if (i0 == n && ne[0] > 2*n) {
404+
printf("..., ");
405+
i0 = ne[0] - n;
406+
}
407+
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
408+
float v;
409+
if (type == GGML_TYPE_F16) {
410+
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
411+
} else if (type == GGML_TYPE_F32) {
412+
v = *(float *) &data[i];
413+
} else if (type == GGML_TYPE_I32) {
414+
v = (float) *(int32_t *) &data[i];
415+
} else if (type == GGML_TYPE_I16) {
416+
v = (float) *(int16_t *) &data[i];
417+
} else if (type == GGML_TYPE_I8) {
418+
v = (float) *(int8_t *) &data[i];
419+
} else {
420+
GGML_ABORT("fatal error");
421+
}
422+
printf("%8.4f", v);
423+
if (i0 < ne[0] - 1) printf(", ");
424+
}
425+
printf("],\n");
426+
}
427+
printf(" ],\n");
428+
}
429+
printf(" ]\n");
430+
}
431+
}
432+
361433
//
362434
// API used internally with mtmd
363435
//

0 commit comments

Comments
 (0)