33from dataclasses import dataclass
44from typing import Optional , Any , Tuple
55import math
6+ from tqdm import tqdm
7+ import comfy .utils
68
79from comfy .ldm .modules .attention import optimized_attention_for_device
810import comfy .model_management
@@ -313,6 +315,13 @@ class Gemma3_4B_Config:
313315 final_norm : bool = True
314316 lm_head : bool = False
315317
# SigLIP vision-tower hyperparameters shared by the Gemma3 multimodal configs
# (values match the stock siglip_vision_model used by Gemma3).
GEMMA3_VISION_CONFIG = {
    "num_channels": 3,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 896,
    "intermediate_size": 4304,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "patch_size": 14,
}
319+
@dataclass
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
    """Gemma3 4B text config extended with the SigLIP vision tower.

    ``vision_config`` and ``mm_tokens_per_image`` are deliberately plain
    (unannotated) class attributes so the dataclass machinery does not turn
    them into init fields — this mirrors how Gemma3_12B_Config declares them.
    """
    vision_config = GEMMA3_VISION_CONFIG
    mm_tokens_per_image = 256
324+
316325@dataclass
317326class Gemma3_12B_Config :
318327 vocab_size : int = 262208
@@ -336,7 +345,7 @@ class Gemma3_12B_Config:
336345 rope_scale = [8.0 , 1.0 ]
337346 final_norm : bool = True
338347 lm_head : bool = False
339- vision_config = { "num_channels" : 3 , "hidden_act" : "gelu_pytorch_tanh" , "hidden_size" : 1152 , "image_size" : 896 , "intermediate_size" : 4304 , "model_type" : "siglip_vision_model" , "num_attention_heads" : 16 , "num_hidden_layers" : 27 , "patch_size" : 14 }
348+ vision_config = GEMMA3_VISION_CONFIG
340349 mm_tokens_per_image = 256
341350
342351class RMSNorm (nn .Module ):
@@ -441,8 +450,10 @@ def forward(
441450 freqs_cis : Optional [torch .Tensor ] = None ,
442451 optimized_attention = None ,
443452 past_key_value : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
453+ sliding_window : Optional [int ] = None ,
444454 ):
445455 batch_size , seq_length , _ = hidden_states .shape
456+
446457 xq = self .q_proj (hidden_states )
447458 xk = self .k_proj (hidden_states )
448459 xv = self .v_proj (hidden_states )
@@ -477,6 +488,11 @@ def forward(
477488 else :
478489 present_key_value = (xk , xv , index + num_tokens )
479490
491+ if sliding_window is not None and xk .shape [2 ] > sliding_window :
492+ xk = xk [:, :, - sliding_window :]
493+ xv = xv [:, :, - sliding_window :]
494+ attention_mask = attention_mask [..., - sliding_window :] if attention_mask is not None else None
495+
480496 xk = xk .repeat_interleave (self .num_heads // self .num_kv_heads , dim = 1 )
481497 xv = xv .repeat_interleave (self .num_heads // self .num_kv_heads , dim = 1 )
482498
@@ -559,10 +575,12 @@ def forward(
559575 optimized_attention = None ,
560576 past_key_value : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
561577 ):
578+ sliding_window = None
562579 if self .transformer_type == 'gemma3' :
563580 if self .sliding_attention :
581+ sliding_window = self .sliding_attention
564582 if x .shape [1 ] > self .sliding_attention :
565- sliding_mask = torch .full ((x .shape [1 ], x .shape [1 ]), float ( "-inf" ) , device = x .device , dtype = x .dtype )
583+ sliding_mask = torch .full ((x .shape [1 ], x .shape [1 ]), torch . finfo ( x . dtype ). min , device = x .device , dtype = x .dtype )
566584 sliding_mask .tril_ (diagonal = - self .sliding_attention )
567585 if attention_mask is not None :
568586 attention_mask = attention_mask + sliding_mask
@@ -581,6 +599,7 @@ def forward(
581599 freqs_cis = freqs_cis ,
582600 optimized_attention = optimized_attention ,
583601 past_key_value = past_key_value ,
602+ sliding_window = sliding_window ,
584603 )
585604
586605 x = self .post_attention_layernorm (x )
@@ -765,6 +784,104 @@ def set_input_embeddings(self, embeddings):
765784 def forward (self , input_ids , * args , ** kwargs ):
766785 return self .model (input_ids , * args , ** kwargs )
767786
class BaseGenerate:
    """Mixin adding autoregressive token generation to the language models.

    Expects the host class to provide ``self.model``: a module with a
    ``config``, an ``embed_tokens`` embedding, a ``forward`` accepting
    ``embeds``/``past_key_values``, and optionally an ``lm_head``.
    """

    def logits(self, x):
        """Project the final hidden state to vocabulary logits.

        Uses ``lm_head`` when the model has one, otherwise falls back to the
        tied ``embed_tokens`` weight. Only the last position of ``x`` is
        projected, since generation only samples from the newest token.
        """
        last_hidden = x[:, -1:]
        if hasattr(self.model, "lm_head"):
            module = self.model.lm_head
        else:
            module = self.model.embed_tokens

        offload_stream = None
        if module.comfy_cast_weights:
            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, last_hidden, offloadable=True)
        else:
            # Bugfix: use the selected module's weight. The previous code
            # always read ``self.model.embed_tokens.weight`` here, which was
            # wrong whenever an untied ``lm_head`` exists.
            weight = module.weight.to(x)

        out = torch.nn.functional.linear(last_hidden, weight, None)

        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
        return out

    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=None, execution_dtype=None, min_tokens=0):
        """Generate up to ``max_length`` tokens starting from ``embeds``.

        Args:
            embeds: prompt embeddings, (seq, dim) or (batch, seq, dim).
            do_sample: sample from the distribution; greedy argmax otherwise.
            stop_tokens: token ids that end generation (default: none).
            initial_tokens: prompt token ids, used for repetition penalty.
            min_tokens: minimum number of tokens before a stop token may end
                generation.

        Returns:
            list[int]: the generated token ids (stop token included).
        """
        # Bugfix: ``stop_tokens=[]`` / ``initial_tokens=[]`` were mutable
        # default arguments shared across calls.
        stop_tokens = [] if stop_tokens is None else stop_tokens
        initial_tokens = [] if initial_tokens is None else initial_tokens

        device = embeds.device
        model_config = self.model.config

        if execution_dtype is None:
            if comfy.model_management.should_use_bf16(device):
                execution_dtype = torch.bfloat16
            else:
                execution_dtype = torch.float32
        embeds = embeds.to(execution_dtype)

        if embeds.ndim == 2:
            embeds = embeds.unsqueeze(0)

        # Pre-allocate the KV cache for the prompt plus all generated tokens;
        # each layer holds (keys, values, write_index).
        past_key_values = []
        max_cache_len = embeds.shape[1] + max_length
        cache_shape = [embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim]
        for _ in range(model_config.num_hidden_layers):
            past_key_values.append((torch.empty(cache_shape, device=device, dtype=execution_dtype),
                                    torch.empty(cache_shape, device=device, dtype=execution_dtype), 0))

        generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None

        generated_token_ids = []
        pbar = comfy.utils.ProgressBar(max_length)

        # Generation loop: after the first step, ``embeds`` is just the
        # embedding of the newest token (the KV cache holds the history).
        for _ in tqdm(range(max_length), desc="Generating tokens"):
            x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
            logits = self.logits(x)[:, -1]
            next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
            token_id = next_token[0].item()
            generated_token_ids.append(token_id)

            embeds = self.model.embed_tokens(next_token).to(execution_dtype)
            pbar.update(1)

            # Bugfix: ``min_tokens`` was accepted but ignored; a stop token
            # now only ends generation once at least that many tokens exist.
            if token_id in stop_tokens and len(generated_token_ids) >= min_tokens:
                break

        return generated_token_ids

    def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
        """Pick the next token id from ``logits`` of shape (batch, vocab).

        Greedy argmax when ``do_sample`` is False or ``temperature`` is 0;
        otherwise applies repetition penalty, temperature, then top-k,
        min-p and top-p filtering before multinomial sampling.
        """
        if not do_sample or temperature == 0.0:
            return torch.argmax(logits, dim=-1, keepdim=True)

        # Repetition penalty: push previously seen tokens away from being
        # re-sampled (scale up negative logits, scale down positive ones).
        if repetition_penalty != 1.0:
            for i in range(logits.shape[0]):
                for token_id in set(token_history):
                    logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1 / repetition_penalty

        if temperature != 1.0:
            logits = logits / temperature

        if top_k > 0:
            # Mask everything below the k-th largest logit.
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = torch.finfo(logits.dtype).min

        if min_p > 0.0:
            # Drop tokens whose probability is below min_p times the top
            # token's probability.
            probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
            top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
            min_threshold = min_p * top_probs
            indices_to_remove = probs_before_filter < min_threshold
            logits[indices_to_remove] = torch.finfo(logits.dtype).min

        if top_p < 1.0:
            # Nucleus sampling: keep the smallest prefix of the sorted
            # distribution whose cumulative probability exceeds top_p (the
            # top token is always kept).
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 0] = False
            indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
            indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = torch.finfo(logits.dtype).min

        probs = torch.nn.functional.softmax(logits, dim=-1)

        return torch.multinomial(probs, num_samples=1, generator=generator)
884+
768885class BaseQwen3 :
769886 def logits (self , x ):
770887 input = x [:, - 1 :]
@@ -871,7 +988,7 @@ def __init__(self, config_dict, dtype, device, operations):
871988 self .model = Llama2_ (config , device = device , dtype = dtype , ops = operations )
872989 self .dtype = dtype
873990
874- class Qwen25_7BVLI (BaseLlama , torch .nn .Module ):
991+ class Qwen25_7BVLI (BaseLlama , BaseGenerate , torch .nn .Module ):
875992 def __init__ (self , config_dict , dtype , device , operations ):
876993 super ().__init__ ()
877994 config = Qwen25_7BVLI_Config (** config_dict )
@@ -881,6 +998,9 @@ def __init__(self, config_dict, dtype, device, operations):
881998 self .visual = qwen_vl .Qwen2VLVisionTransformer (hidden_size = 1280 , output_hidden_size = config .hidden_size , device = device , dtype = dtype , ops = operations )
882999 self .dtype = dtype
8831000
1001+ # todo: should this be tied or not?
1002+ #self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
1003+
8841004 def preprocess_embed (self , embed , device ):
8851005 if embed ["type" ] == "image" :
8861006 image , grid = qwen_vl .process_qwen2vl_images (embed ["data" ])
@@ -923,7 +1043,7 @@ def __init__(self, config_dict, dtype, device, operations):
9231043 self .model = Llama2_ (config , device = device , dtype = dtype , ops = operations )
9241044 self .dtype = dtype
9251045
926- class Gemma3_4B (BaseLlama , torch .nn .Module ):
1046+ class Gemma3_4B (BaseLlama , BaseGenerate , torch .nn .Module ):
9271047 def __init__ (self , config_dict , dtype , device , operations ):
9281048 super ().__init__ ()
9291049 config = Gemma3_4B_Config (** config_dict )
@@ -932,7 +1052,25 @@ def __init__(self, config_dict, dtype, device, operations):
9321052 self .model = Llama2_ (config , device = device , dtype = dtype , ops = operations )
9331053 self .dtype = dtype
9341054
935- class Gemma3_12B (BaseLlama , torch .nn .Module ):
class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
    """Gemma3 4B language model paired with a SigLIP vision tower and a
    multimodal projector that maps image features into the text space."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = Gemma3_4B_Vision_Config(**config_dict)
        self.num_layers = config.num_hidden_layers

        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype
        self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
        self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
        self.image_size = config.vision_config["image_size"]

    def preprocess_embed(self, embed, device):
        """Convert an ``{"type": "image", "data": ...}`` embed into projected
        vision tokens; any other embed type yields ``(None, None)``."""
        if embed["type"] != "image":
            return None, None
        # Resize/center-crop to the SigLIP input size, normalized with
        # mean/std 0.5 (i.e. pixel range [-1, 1]).
        pixels = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
        # NOTE(review): the vision tower is run in float32 here regardless of
        # self.dtype — presumably intentional for accuracy; confirm.
        features = self.vision_model(pixels.to(device, dtype=torch.float32))[0]
        return self.multi_modal_projector(features), None
1072+
1073+ class Gemma3_12B (BaseLlama , BaseGenerate , torch .nn .Module ):
9361074 def __init__ (self , config_dict , dtype , device , operations ):
9371075 super ().__init__ ()
9381076 config = Gemma3_12B_Config (** config_dict )
0 commit comments