Skip to content

Commit a422ba3

Browse files
committed
roughly count gemma3 graph
the largest operation is by far (q @ k) so just count that for simplicity
1 parent d2ec223 commit a422ba3

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

fs/ggml/ggml.go

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -587,34 +587,32 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
587587
}
588588
}
589589

590-
switch llm.KV().Architecture() {
591-
case "mllama":
592-
kv := func(n string) uint64 {
593-
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
594-
return uint64(v)
595-
}
596-
597-
return 0
598-
}
590+
imageSize := uint64(llm.KV().Uint("vision.image_size"))
591+
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
599592

600-
imageSize := kv("image_size")
593+
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
594+
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
595+
numPatches++
596+
}
601597

602-
maxNumTiles := kv("max_num_tiles")
603-
embeddingLength := kv("embedding_length")
604-
headCount := kv("attention.head_count")
598+
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
605599

606-
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
607-
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
608-
numPatches++
609-
}
600+
switch llm.KV().Architecture() {
601+
case "mllama":
610602

611603
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
612604

605+
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
606+
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
607+
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
608+
613609
graphSize = 4 * (8 +
614-
imageSize*imageSize*kv("num_channels")*maxNumTiles +
610+
imageSize*imageSize*numChannels*maxNumTiles +
615611
embeddingLength*numPatches*maxNumTiles +
616612
9*embeddingLength*numPaddedPatches*maxNumTiles +
617613
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
614+
case "gemma3":
615+
graphSize = 4 * (numPatches * numPatches * headCount)
618616
}
619617

620618
return weights, graphSize

0 commit comments

Comments
 (0)