@@ -587,34 +587,32 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
587
587
}
588
588
}
589
589
590
- switch llm .KV ().Architecture () {
591
- case "mllama" :
592
- kv := func (n string ) uint64 {
593
- if v , ok := llm .KV ()["mllama.vision." + n ].(uint32 ); ok {
594
- return uint64 (v )
595
- }
596
-
597
- return 0
598
- }
590
+ imageSize := uint64 (llm .KV ().Uint ("vision.image_size" ))
591
+ patchSize := uint64 (llm .KV ().Uint ("vision.patch_size" ))
599
592
600
- imageSize := kv ("image_size" )
593
+ numPatches := (imageSize / patchSize ) * (imageSize / patchSize )
594
+ if _ , ok := llm .Tensors ().GroupLayers ()["v" ]["class_embd" ]; ok {
595
+ numPatches ++
596
+ }
601
597
602
- maxNumTiles := kv ("max_num_tiles" )
603
- embeddingLength := kv ("embedding_length" )
604
- headCount := kv ("attention.head_count" )
598
+ headCount := uint64 (llm .KV ().Uint ("vision.attention.head_count" ))
605
599
606
- numPatches := (imageSize / kv ("patch_size" )) * (imageSize / kv ("patch_size" ))
607
- if _ , ok := llm .Tensors ().GroupLayers ()["v" ]["class_embd" ]; ok {
608
- numPatches ++
609
- }
600
+ switch llm .KV ().Architecture () {
601
+ case "mllama" :
610
602
611
603
numPaddedPatches := numPatches + 8 - (numPatches % 8 )% 8
612
604
605
+ maxNumTiles := uint64 (llm .KV ().Uint ("vision.max_num_tiles" ))
606
+ numChannels := uint64 (llm .KV ().Uint ("vision.num_channels" ))
607
+ embeddingLength := uint64 (llm .KV ().Uint ("vision.embedding_length" ))
608
+
613
609
graphSize = 4 * (8 +
614
- imageSize * imageSize * kv ( "num_channels" ) * maxNumTiles +
610
+ imageSize * imageSize * numChannels * maxNumTiles +
615
611
embeddingLength * numPatches * maxNumTiles +
616
612
9 * embeddingLength * numPaddedPatches * maxNumTiles +
617
613
numPaddedPatches * maxNumTiles * numPaddedPatches * maxNumTiles * headCount )
614
+ case "gemma3" :
615
+ graphSize = 4 * (numPatches * numPatches * headCount )
618
616
}
619
617
620
618
return weights , graphSize
0 commit comments