Skip to content

Commit 6d11cc7

Browse files
authored
feat: Add basic text generation support with native models, initially supporting Gemma3 (Comfy-Org#12392)
1 parent f262444 commit 6d11cc7

File tree

9 files changed

+501
-32
lines changed

9 files changed

+501
-32
lines changed

comfy/sd.py

Lines changed: 28 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -423,6 +423,19 @@ def load_model(self, tokens={}):
423423
def get_key_patches(self):
424424
return self.patcher.get_key_patches()
425425

426+
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
427+
self.cond_stage_model.reset_clip_options()
428+
429+
if self.layer_idx is not None:
430+
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
431+
432+
self.load_model()
433+
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
434+
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
435+
436+
def decode(self, token_ids, skip_special_tokens=True):
437+
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
438+
426439
class VAE:
427440
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
428441
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@@ -1182,6 +1195,7 @@ class TEModel(Enum):
11821195
JINA_CLIP_2 = 19
11831196
QWEN3_8B = 20
11841197
QWEN3_06B = 21
1198+
GEMMA_3_4B_VISION = 22
11851199

11861200

11871201
def detect_te_model(sd):
@@ -1210,7 +1224,10 @@ def detect_te_model(sd):
12101224
if 'model.layers.47.self_attn.q_norm.weight' in sd:
12111225
return TEModel.GEMMA_3_12B
12121226
if 'model.layers.0.self_attn.q_norm.weight' in sd:
1213-
return TEModel.GEMMA_3_4B
1227+
if 'vision_model.embeddings.patch_embedding.weight' in sd:
1228+
return TEModel.GEMMA_3_4B_VISION
1229+
else:
1230+
return TEModel.GEMMA_3_4B
12141231
return TEModel.GEMMA_2_2B
12151232
if 'model.layers.0.self_attn.k_proj.bias' in sd:
12161233
weight = sd['model.layers.0.self_attn.k_proj.bias']
@@ -1270,6 +1287,8 @@ class EmptyClass:
12701287
else:
12711288
if "text_projection" in clip_data[i]:
12721289
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
1290+
if "lm_head.weight" in clip_data[i]:
1291+
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
12731292

12741293
tokenizer_data = {}
12751294
clip_target = EmptyClass()
@@ -1335,6 +1354,14 @@ class EmptyClass:
13351354
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
13361355
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
13371356
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
1357+
elif te_model == TEModel.GEMMA_3_4B_VISION:
1358+
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
1359+
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
1360+
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
1361+
elif te_model == TEModel.GEMMA_3_12B:
1362+
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
1363+
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
1364+
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
13381365
elif te_model == TEModel.LLAMA3_8:
13391366
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
13401367
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)

comfy/sd1_clip.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -308,6 +308,15 @@ def encode(self, tokens):
308308
def load_sd(self, sd):
309309
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
310310

311+
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
312+
if isinstance(tokens, dict):
313+
tokens_only = next(iter(tokens.values())) # todo: get this better?
314+
else:
315+
tokens_only = tokens
316+
tokens_only = [[t[0] for t in b] for b in tokens_only]
317+
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
318+
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens)
319+
311320
def parse_parentheses(string):
312321
result = []
313322
current_item = ""
@@ -663,6 +672,9 @@ def untokenize(self, token_weight_pair):
663672
def state_dict(self):
664673
return {}
665674

675+
def decode(self, token_ids, skip_special_tokens=True):
676+
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
677+
666678
class SD1Tokenizer:
667679
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
668680
if name is not None:
@@ -686,6 +698,9 @@ def untokenize(self, token_weight_pair):
686698
def state_dict(self):
687699
return getattr(self, self.clip).state_dict()
688700

701+
def decode(self, token_ids, skip_special_tokens=True):
702+
return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
703+
689704
class SD1CheckpointClipModel(SDClipModel):
690705
def __init__(self, device="cpu", dtype=None, model_options={}):
691706
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
@@ -722,3 +737,6 @@ def encode_token_weights(self, token_weight_pairs):
722737

723738
def load_sd(self, sd):
724739
return getattr(self, self.clip).load_sd(sd)
740+
741+
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
742+
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

comfy/text_encoders/llama.py

Lines changed: 143 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
from dataclasses import dataclass
44
from typing import Optional, Any, Tuple
55
import math
6+
from tqdm import tqdm
7+
import comfy.utils
68

79
from comfy.ldm.modules.attention import optimized_attention_for_device
810
import comfy.model_management
@@ -313,6 +315,13 @@ class Gemma3_4B_Config:
313315
final_norm: bool = True
314316
lm_head: bool = False
315317

318+
GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
319+
320+
@dataclass
321+
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
322+
vision_config = GEMMA3_VISION_CONFIG
323+
mm_tokens_per_image = 256
324+
316325
@dataclass
317326
class Gemma3_12B_Config:
318327
vocab_size: int = 262208
@@ -336,7 +345,7 @@ class Gemma3_12B_Config:
336345
rope_scale = [8.0, 1.0]
337346
final_norm: bool = True
338347
lm_head: bool = False
339-
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
348+
vision_config = GEMMA3_VISION_CONFIG
340349
mm_tokens_per_image = 256
341350

342351
class RMSNorm(nn.Module):
@@ -441,8 +450,10 @@ def forward(
441450
freqs_cis: Optional[torch.Tensor] = None,
442451
optimized_attention=None,
443452
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
453+
sliding_window: Optional[int] = None,
444454
):
445455
batch_size, seq_length, _ = hidden_states.shape
456+
446457
xq = self.q_proj(hidden_states)
447458
xk = self.k_proj(hidden_states)
448459
xv = self.v_proj(hidden_states)
@@ -477,6 +488,11 @@ def forward(
477488
else:
478489
present_key_value = (xk, xv, index + num_tokens)
479490

491+
if sliding_window is not None and xk.shape[2] > sliding_window:
492+
xk = xk[:, :, -sliding_window:]
493+
xv = xv[:, :, -sliding_window:]
494+
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
495+
480496
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
481497
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
482498

@@ -559,10 +575,12 @@ def forward(
559575
optimized_attention=None,
560576
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
561577
):
578+
sliding_window = None
562579
if self.transformer_type == 'gemma3':
563580
if self.sliding_attention:
581+
sliding_window = self.sliding_attention
564582
if x.shape[1] > self.sliding_attention:
565-
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
583+
sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
566584
sliding_mask.tril_(diagonal=-self.sliding_attention)
567585
if attention_mask is not None:
568586
attention_mask = attention_mask + sliding_mask
@@ -581,6 +599,7 @@ def forward(
581599
freqs_cis=freqs_cis,
582600
optimized_attention=optimized_attention,
583601
past_key_value=past_key_value,
602+
sliding_window=sliding_window,
584603
)
585604

586605
x = self.post_attention_layernorm(x)
@@ -765,6 +784,104 @@ def set_input_embeddings(self, embeddings):
765784
def forward(self, input_ids, *args, **kwargs):
766785
return self.model(input_ids, *args, **kwargs)
767786

787+
class BaseGenerate:
788+
def logits(self, x):
789+
input = x[:, -1:]
790+
if hasattr(self.model, "lm_head"):
791+
module = self.model.lm_head
792+
else:
793+
module = self.model.embed_tokens
794+
795+
offload_stream = None
796+
if module.comfy_cast_weights:
797+
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
798+
else:
799+
weight = self.model.embed_tokens.weight.to(x)
800+
801+
x = torch.nn.functional.linear(input, weight, None)
802+
803+
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
804+
return x
805+
806+
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
807+
device = embeds.device
808+
model_config = self.model.config
809+
810+
if execution_dtype is None:
811+
if comfy.model_management.should_use_bf16(device):
812+
execution_dtype = torch.bfloat16
813+
else:
814+
execution_dtype = torch.float32
815+
embeds = embeds.to(execution_dtype)
816+
817+
if embeds.ndim == 2:
818+
embeds = embeds.unsqueeze(0)
819+
820+
past_key_values = [] #kv_cache init
821+
max_cache_len = embeds.shape[1] + max_length
822+
for x in range(model_config.num_hidden_layers):
823+
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
824+
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
825+
826+
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
827+
828+
generated_token_ids = []
829+
pbar = comfy.utils.ProgressBar(max_length)
830+
831+
# Generation loop
832+
for step in tqdm(range(max_length), desc="Generating tokens"):
833+
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
834+
logits = self.logits(x)[:, -1]
835+
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
836+
token_id = next_token[0].item()
837+
generated_token_ids.append(token_id)
838+
839+
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
840+
pbar.update(1)
841+
842+
if token_id in stop_tokens:
843+
break
844+
845+
return generated_token_ids
846+
847+
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
848+
849+
if not do_sample or temperature == 0.0:
850+
return torch.argmax(logits, dim=-1, keepdim=True)
851+
852+
# Sampling mode
853+
if repetition_penalty != 1.0:
854+
for i in range(logits.shape[0]):
855+
for token_id in set(token_history):
856+
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
857+
858+
if temperature != 1.0:
859+
logits = logits / temperature
860+
861+
if top_k > 0:
862+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
863+
logits[indices_to_remove] = torch.finfo(logits.dtype).min
864+
865+
if min_p > 0.0:
866+
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
867+
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
868+
min_threshold = min_p * top_probs
869+
indices_to_remove = probs_before_filter < min_threshold
870+
logits[indices_to_remove] = torch.finfo(logits.dtype).min
871+
872+
if top_p < 1.0:
873+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
874+
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
875+
sorted_indices_to_remove = cumulative_probs > top_p
876+
sorted_indices_to_remove[..., 0] = False
877+
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
878+
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
879+
logits[indices_to_remove] = torch.finfo(logits.dtype).min
880+
881+
probs = torch.nn.functional.softmax(logits, dim=-1)
882+
883+
return torch.multinomial(probs, num_samples=1, generator=generator)
884+
768885
class BaseQwen3:
769886
def logits(self, x):
770887
input = x[:, -1:]
@@ -871,7 +988,7 @@ def __init__(self, config_dict, dtype, device, operations):
871988
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
872989
self.dtype = dtype
873990

874-
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
991+
class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
875992
def __init__(self, config_dict, dtype, device, operations):
876993
super().__init__()
877994
config = Qwen25_7BVLI_Config(**config_dict)
@@ -881,6 +998,9 @@ def __init__(self, config_dict, dtype, device, operations):
881998
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
882999
self.dtype = dtype
8831000

1001+
# todo: should this be tied or not?
1002+
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
1003+
8841004
def preprocess_embed(self, embed, device):
8851005
if embed["type"] == "image":
8861006
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
@@ -923,7 +1043,7 @@ def __init__(self, config_dict, dtype, device, operations):
9231043
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
9241044
self.dtype = dtype
9251045

926-
class Gemma3_4B(BaseLlama, torch.nn.Module):
1046+
class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
9271047
def __init__(self, config_dict, dtype, device, operations):
9281048
super().__init__()
9291049
config = Gemma3_4B_Config(**config_dict)
@@ -932,7 +1052,25 @@ def __init__(self, config_dict, dtype, device, operations):
9321052
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
9331053
self.dtype = dtype
9341054

935-
class Gemma3_12B(BaseLlama, torch.nn.Module):
1055+
class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
1056+
def __init__(self, config_dict, dtype, device, operations):
1057+
super().__init__()
1058+
config = Gemma3_4B_Vision_Config(**config_dict)
1059+
self.num_layers = config.num_hidden_layers
1060+
1061+
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
1062+
self.dtype = dtype
1063+
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
1064+
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
1065+
self.image_size = config.vision_config["image_size"]
1066+
1067+
def preprocess_embed(self, embed, device):
1068+
if embed["type"] == "image":
1069+
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
1070+
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
1071+
return None, None
1072+
1073+
class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
9361074
def __init__(self, config_dict, dtype, device, operations):
9371075
super().__init__()
9381076
config = Gemma3_12B_Config(**config_dict)

0 commit comments

Comments (0)