11import torch
22import torch .nn as nn
33from dataclasses import dataclass
4- from typing import Optional , Any
4+ from typing import Optional , Any , Tuple
55import math
66
77from comfy .ldm .modules .attention import optimized_attention_for_device
@@ -32,6 +32,7 @@ class Llama2Config:
3232 k_norm = None
3333 rope_scale = None
3434 final_norm : bool = True
35+ lm_head : bool = False
3536
3637@dataclass
3738class Mistral3Small24BConfig :
@@ -54,6 +55,7 @@ class Mistral3Small24BConfig:
5455 k_norm = None
5556 rope_scale = None
5657 final_norm : bool = True
58+ lm_head : bool = False
5759
5860@dataclass
5961class Qwen25_3BConfig :
@@ -76,6 +78,7 @@ class Qwen25_3BConfig:
7678 k_norm = None
7779 rope_scale = None
7880 final_norm : bool = True
81+ lm_head : bool = False
7982
8083@dataclass
8184class Qwen3_06BConfig :
@@ -98,6 +101,7 @@ class Qwen3_06BConfig:
98101 k_norm = "gemma3"
99102 rope_scale = None
100103 final_norm : bool = True
104+ lm_head : bool = False
101105
102106@dataclass
103107class Qwen3_4BConfig :
@@ -120,6 +124,7 @@ class Qwen3_4BConfig:
120124 k_norm = "gemma3"
121125 rope_scale = None
122126 final_norm : bool = True
127+ lm_head : bool = False
123128
124129@dataclass
125130class Qwen3_8BConfig :
@@ -142,6 +147,7 @@ class Qwen3_8BConfig:
142147 k_norm = "gemma3"
143148 rope_scale = None
144149 final_norm : bool = True
150+ lm_head : bool = False
145151
146152@dataclass
147153class Ovis25_2BConfig :
@@ -164,6 +170,7 @@ class Ovis25_2BConfig:
164170 k_norm = "gemma3"
165171 rope_scale = None
166172 final_norm : bool = True
173+ lm_head : bool = False
167174
168175@dataclass
169176class Qwen25_7BVLI_Config :
@@ -186,6 +193,7 @@ class Qwen25_7BVLI_Config:
186193 k_norm = None
187194 rope_scale = None
188195 final_norm : bool = True
196+ lm_head : bool = False
189197
190198@dataclass
191199class Gemma2_2B_Config :
@@ -209,6 +217,7 @@ class Gemma2_2B_Config:
209217 sliding_attention = None
210218 rope_scale = None
211219 final_norm : bool = True
220+ lm_head : bool = False
212221
213222@dataclass
214223class Gemma3_4B_Config :
@@ -232,6 +241,7 @@ class Gemma3_4B_Config:
232241 sliding_attention = [1024 , 1024 , 1024 , 1024 , 1024 , False ]
233242 rope_scale = [8.0 , 1.0 ]
234243 final_norm : bool = True
244+ lm_head : bool = False
235245
236246@dataclass
237247class Gemma3_12B_Config :
@@ -255,6 +265,7 @@ class Gemma3_12B_Config:
255265 sliding_attention = [1024 , 1024 , 1024 , 1024 , 1024 , False ]
256266 rope_scale = [8.0 , 1.0 ]
257267 final_norm : bool = True
268+ lm_head : bool = False
258269 vision_config = {"num_channels" : 3 , "hidden_act" : "gelu_pytorch_tanh" , "hidden_size" : 1152 , "image_size" : 896 , "intermediate_size" : 4304 , "model_type" : "siglip_vision_model" , "num_attention_heads" : 16 , "num_hidden_layers" : 27 , "patch_size" : 14 }
259270 mm_tokens_per_image = 256
260271
@@ -356,6 +367,7 @@ def forward(
356367 attention_mask : Optional [torch .Tensor ] = None ,
357368 freqs_cis : Optional [torch .Tensor ] = None ,
358369 optimized_attention = None ,
370+ past_key_value : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
359371 ):
360372 batch_size , seq_length , _ = hidden_states .shape
361373 xq = self .q_proj (hidden_states )
@@ -373,11 +385,30 @@ def forward(
373385
374386 xq , xk = apply_rope (xq , xk , freqs_cis = freqs_cis )
375387
388+ present_key_value = None
389+ if past_key_value is not None :
390+ index = 0
391+ num_tokens = xk .shape [2 ]
392+ if len (past_key_value ) > 0 :
393+ past_key , past_value , index = past_key_value
394+ if past_key .shape [2 ] >= (index + num_tokens ):
395+ past_key [:, :, index :index + xk .shape [2 ]] = xk
396+ past_value [:, :, index :index + xv .shape [2 ]] = xv
397+ xk = past_key [:, :, :index + xk .shape [2 ]]
398+ xv = past_value [:, :, :index + xv .shape [2 ]]
399+ present_key_value = (past_key , past_value , index + num_tokens )
400+ else :
401+ xk = torch .cat ((past_key [:, :, :index ], xk ), dim = 2 )
402+ xv = torch .cat ((past_value [:, :, :index ], xv ), dim = 2 )
403+ present_key_value = (xk , xv , index + num_tokens )
404+ else :
405+ present_key_value = (xk , xv , index + num_tokens )
406+
376407 xk = xk .repeat_interleave (self .num_heads // self .num_kv_heads , dim = 1 )
377408 xv = xv .repeat_interleave (self .num_heads // self .num_kv_heads , dim = 1 )
378409
379410 output = optimized_attention (xq , xk , xv , self .num_heads , mask = attention_mask , skip_reshape = True )
380- return self .o_proj (output )
411+ return self .o_proj (output ), present_key_value
381412
382413class MLP (nn .Module ):
383414 def __init__ (self , config : Llama2Config , device = None , dtype = None , ops : Any = None ):
@@ -408,15 +439,17 @@ def forward(
408439 attention_mask : Optional [torch .Tensor ] = None ,
409440 freqs_cis : Optional [torch .Tensor ] = None ,
410441 optimized_attention = None ,
442+ past_key_value : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
411443 ):
412444 # Self Attention
413445 residual = x
414446 x = self .input_layernorm (x )
415- x = self .self_attn (
447+ x , present_key_value = self .self_attn (
416448 hidden_states = x ,
417449 attention_mask = attention_mask ,
418450 freqs_cis = freqs_cis ,
419451 optimized_attention = optimized_attention ,
452+ past_key_value = past_key_value ,
420453 )
421454 x = residual + x
422455
@@ -426,7 +459,7 @@ def forward(
426459 x = self .mlp (x )
427460 x = residual + x
428461
429- return x
462+ return x , present_key_value
430463
431464class TransformerBlockGemma2 (nn .Module ):
432465 def __init__ (self , config : Llama2Config , index , device = None , dtype = None , ops : Any = None ):
@@ -451,6 +484,7 @@ def forward(
451484 attention_mask : Optional [torch .Tensor ] = None ,
452485 freqs_cis : Optional [torch .Tensor ] = None ,
453486 optimized_attention = None ,
487+ past_key_value : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
454488 ):
455489 if self .transformer_type == 'gemma3' :
456490 if self .sliding_attention :
@@ -468,11 +502,12 @@ def forward(
468502 # Self Attention
469503 residual = x
470504 x = self .input_layernorm (x )
471- x = self .self_attn (
505+ x , present_key_value = self .self_attn (
472506 hidden_states = x ,
473507 attention_mask = attention_mask ,
474508 freqs_cis = freqs_cis ,
475509 optimized_attention = optimized_attention ,
510+ past_key_value = past_key_value ,
476511 )
477512
478513 x = self .post_attention_layernorm (x )
@@ -485,7 +520,7 @@ def forward(
485520 x = self .post_feedforward_layernorm (x )
486521 x = residual + x
487522
488- return x
523+ return x , present_key_value
489524
490525class Llama2_ (nn .Module ):
491526 def __init__ (self , config , device = None , dtype = None , ops = None ):
@@ -516,9 +551,10 @@ def __init__(self, config, device=None, dtype=None, ops=None):
516551 else :
517552 self .norm = None
518553
519- # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
554+ if config .lm_head :
555+ self .lm_head = ops .Linear (config .hidden_size , config .vocab_size , bias = False , device = device , dtype = dtype )
520556
521- def forward (self , x , attention_mask = None , embeds = None , num_tokens = None , intermediate_output = None , final_layer_norm_intermediate = True , dtype = None , position_ids = None , embeds_info = []):
557+ def forward (self , x , attention_mask = None , embeds = None , num_tokens = None , intermediate_output = None , final_layer_norm_intermediate = True , dtype = None , position_ids = None , embeds_info = [], past_key_values = None ):
522558 if embeds is not None :
523559 x = embeds
524560 else :
@@ -527,8 +563,13 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
527563 if self .normalize_in :
528564 x *= self .config .hidden_size ** 0.5
529565
566+ seq_len = x .shape [1 ]
567+ past_len = 0
568+ if past_key_values is not None and len (past_key_values ) > 0 :
569+ past_len = past_key_values [0 ][2 ]
570+
530571 if position_ids is None :
531- position_ids = torch .arange (0 , x . shape [ 1 ] , device = x .device ).unsqueeze (0 )
572+ position_ids = torch .arange (past_len , past_len + seq_len , device = x .device ).unsqueeze (0 )
532573
533574 freqs_cis = precompute_freqs_cis (self .config .head_dim ,
534575 position_ids ,
@@ -539,14 +580,16 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
539580
540581 mask = None
541582 if attention_mask is not None :
542- mask = 1.0 - attention_mask .to (x .dtype ).reshape ((attention_mask .shape [0 ], 1 , - 1 , attention_mask .shape [- 1 ])).expand (attention_mask .shape [0 ], 1 , attention_mask . shape [ - 1 ] , attention_mask .shape [- 1 ])
583+ mask = 1.0 - attention_mask .to (x .dtype ).reshape ((attention_mask .shape [0 ], 1 , - 1 , attention_mask .shape [- 1 ])).expand (attention_mask .shape [0 ], 1 , seq_len , attention_mask .shape [- 1 ])
543584 mask = mask .masked_fill (mask .to (torch .bool ), float ("-inf" ))
544585
545- causal_mask = torch .empty (x .shape [1 ], x .shape [1 ], dtype = x .dtype , device = x .device ).fill_ (float ("-inf" )).triu_ (1 )
546- if mask is not None :
547- mask += causal_mask
548- else :
549- mask = causal_mask
586+ if seq_len > 1 :
587+ causal_mask = torch .empty (past_len + seq_len , past_len + seq_len , dtype = x .dtype , device = x .device ).fill_ (float ("-inf" )).triu_ (1 )
588+ if mask is not None :
589+ mask += causal_mask
590+ else :
591+ mask = causal_mask
592+
550593 optimized_attention = optimized_attention_for_device (x .device , mask = mask is not None , small_input = True )
551594
552595 intermediate = None
@@ -562,16 +605,27 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
562605 elif intermediate_output < 0 :
563606 intermediate_output = len (self .layers ) + intermediate_output
564607
608+ next_key_values = []
565609 for i , layer in enumerate (self .layers ):
566610 if all_intermediate is not None :
567611 if only_layers is None or (i in only_layers ):
568612 all_intermediate .append (x .unsqueeze (1 ).clone ())
569- x = layer (
613+
614+ past_kv = None
615+ if past_key_values is not None :
616+ past_kv = past_key_values [i ] if len (past_key_values ) > 0 else []
617+
618+ x , current_kv = layer (
570619 x = x ,
571620 attention_mask = mask ,
572621 freqs_cis = freqs_cis ,
573622 optimized_attention = optimized_attention ,
623+ past_key_value = past_kv ,
574624 )
625+
626+ if current_kv is not None :
627+ next_key_values .append (current_kv )
628+
575629 if i == intermediate_output :
576630 intermediate = x .clone ()
577631
@@ -588,7 +642,10 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
588642 if intermediate is not None and final_layer_norm_intermediate and self .norm is not None :
589643 intermediate = self .norm (intermediate )
590644
591- return x , intermediate
645+ if len (next_key_values ) > 0 :
646+ return x , intermediate , next_key_values
647+ else :
648+ return x , intermediate
592649
593650
594651class Gemma3MultiModalProjector (torch .nn .Module ):
0 commit comments