Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,8 @@ def get_key_patches(self):
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()

if self.layer_idx is not None:
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})

self.load_model()
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

Expand Down
4 changes: 2 additions & 2 deletions comfy/sd1_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,14 @@ def encode(self, tokens):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))

def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)

def parse_parentheses(string):
result = []
Expand Down
2 changes: 2 additions & 0 deletions comfy/text_encoders/anima.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def untokenize(self, token_weight_pair):
def state_dict(self):
    """Return an empty state dict: this tokenizer wrapper owns no persistent tensors."""
    return dict()

def decode(self, token_ids, **kwargs):
    """Decode token ids to text by delegating to the wrapped qwen3_06b tokenizer."""
    delegate = self.qwen3_06b
    return delegate.decode(token_ids, **kwargs)

class Qwen3_06BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
Expand Down
22 changes: 17 additions & 5 deletions comfy/text_encoders/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class Qwen3_06BConfig:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]

@dataclass
class Qwen3_06B_ACE15_Config:
Expand All @@ -128,6 +129,7 @@ class Qwen3_06B_ACE15_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]

@dataclass
class Qwen3_2B_ACE15_lm_Config:
Expand All @@ -151,6 +153,7 @@ class Qwen3_2B_ACE15_lm_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]

@dataclass
class Qwen3_4B_ACE15_lm_Config:
Expand All @@ -174,6 +177,7 @@ class Qwen3_4B_ACE15_lm_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]

@dataclass
class Qwen3_4BConfig:
Expand All @@ -197,6 +201,7 @@ class Qwen3_4BConfig:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]

@dataclass
class Qwen3_8BConfig:
Expand All @@ -220,6 +225,7 @@ class Qwen3_8BConfig:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]

@dataclass
class Ovis25_2BConfig:
Expand Down Expand Up @@ -290,6 +296,7 @@ class Gemma2_2B_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [1]

@dataclass
class Gemma3_4B_Config:
Expand All @@ -314,6 +321,7 @@ class Gemma3_4B_Config:
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
stop_tokens = [1, 106]

GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}

Expand Down Expand Up @@ -347,6 +355,7 @@ class Gemma3_12B_Config:
lm_head: bool = False
vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256
stop_tokens = [1, 106]

class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
Expand Down Expand Up @@ -803,10 +812,13 @@ def logits(self, x):
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x

def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
model_config = self.model.config

if stop_tokens is None:
stop_tokens = self.model.config.stop_tokens

if execution_dtype is None:
if comfy.model_management.should_use_bf16(device):
execution_dtype = torch.bfloat16
Expand Down Expand Up @@ -925,7 +937,7 @@ def __init__(self, config_dict, dtype, device, operations):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
class Qwen3_06B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
Expand All @@ -952,7 +964,7 @@ def __init__(self, config_dict, dtype, device, operations):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
class Qwen3_4B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
Expand All @@ -970,7 +982,7 @@ def __init__(self, config_dict, dtype, device, operations):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
class Qwen3_8B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_8BConfig(**config_dict)
Expand Down Expand Up @@ -1034,7 +1046,7 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed

return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)

class Gemma2_2B(BaseLlama, torch.nn.Module):
class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma2_2B_Config(**config_dict)
Expand Down
6 changes: 0 additions & 6 deletions comfy/text_encoders/lumina2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[107])

class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
Expand All @@ -43,9 +40,6 @@ def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, atten

super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106])

class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
Expand Down