@@ -579,6 +579,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
579
579
}
580
580
581
581
func (llm GGML ) VisionGraphSize () (weights , graphSize uint64 ) {
582
+ if llm .KV ().Uint ("vision.block_count" ) == 0 {
583
+ return
584
+ }
585
+
582
586
for name , layer := range llm .Tensors ().GroupLayers () {
583
587
if strings .HasPrefix (name , "v." ) {
584
588
for _ , tensor := range layer {
@@ -589,30 +593,36 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
589
593
590
594
imageSize := uint64 (llm .KV ().Uint ("vision.image_size" ))
591
595
patchSize := uint64 (llm .KV ().Uint ("vision.patch_size" ))
596
+ if patchSize == 0 {
597
+ slog .Warn ("unknown patch size for vision model" )
598
+ return
599
+ }
600
+
601
+ numChannels := uint64 (llm .KV ().Uint ("vision.num_channels" ))
592
602
593
603
numPatches := (imageSize / patchSize ) * (imageSize / patchSize )
594
604
if _ , ok := llm .Tensors ().GroupLayers ()["v" ]["class_embd" ]; ok {
595
605
numPatches ++
596
606
}
597
607
598
608
headCount := uint64 (llm .KV ().Uint ("vision.attention.head_count" ))
609
+ embeddingLength := uint64 (llm .KV ().Uint ("vision.embedding_length" ))
599
610
600
611
switch llm .KV ().Architecture () {
601
612
case "mllama" :
602
-
603
613
numPaddedPatches := numPatches + 8 - (numPatches % 8 )% 8
604
614
605
615
maxNumTiles := uint64 (llm .KV ().Uint ("vision.max_num_tiles" ))
606
- numChannels := uint64 (llm .KV ().Uint ("vision.num_channels" ))
607
- embeddingLength := uint64 (llm .KV ().Uint ("vision.embedding_length" ))
608
616
609
617
graphSize = 4 * (8 +
610
618
imageSize * imageSize * numChannels * maxNumTiles +
611
619
embeddingLength * numPatches * maxNumTiles +
612
620
9 * embeddingLength * numPaddedPatches * maxNumTiles +
613
621
numPaddedPatches * maxNumTiles * numPaddedPatches * maxNumTiles * headCount )
614
622
case "gemma3" :
615
- graphSize = 4 * (numPatches * numPatches * headCount )
623
+ graphSize = 4 * (imageSize * imageSize * numChannels +
624
+ embeddingLength * patchSize +
625
+ numPatches * numPatches * headCount )
616
626
}
617
627
618
628
return weights , graphSize
0 commit comments