Skip to content

Commit 3b83223

Browse files
Flux2 Klein support. (Comfy-Org#11890)
1 parent be518db commit 3b83223

File tree

3 files changed

+102
-3
lines changed

3 files changed

+102
-3
lines changed

comfy/sd.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,7 @@ class CLIPType(Enum):
10141014
KANDINSKY5 = 22
10151015
KANDINSKY5_IMAGE = 23
10161016
NEWBIE = 24
1017+
FLUX2 = 25
10171018

10181019

10191020
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -1046,6 +1047,7 @@ class TEModel(Enum):
10461047
QWEN3_2B = 17
10471048
GEMMA_3_12B = 18
10481049
JINA_CLIP_2 = 19
1050+
QWEN3_8B = 20
10491051

10501052

10511053
def detect_te_model(sd):
@@ -1089,6 +1091,8 @@ def detect_te_model(sd):
10891091
return TEModel.QWEN3_4B
10901092
elif weight.shape[0] == 2048:
10911093
return TEModel.QWEN3_2B
1094+
elif weight.shape[0] == 4096:
1095+
return TEModel.QWEN3_8B
10921096
if weight.shape[0] == 5120:
10931097
if "model.layers.39.post_attention_layernorm.weight" in sd:
10941098
return TEModel.MISTRAL3_24B
@@ -1214,11 +1218,18 @@ class EmptyClass:
12141218
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
12151219
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
12161220
elif te_model == TEModel.QWEN3_4B:
1217-
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
1218-
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
1221+
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
1222+
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
1223+
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer
1224+
else:
1225+
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
1226+
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
12191227
elif te_model == TEModel.QWEN3_2B:
12201228
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
12211229
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
1230+
elif te_model == TEModel.QWEN3_8B:
1231+
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
1232+
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
12221233
elif te_model == TEModel.JINA_CLIP_2:
12231234
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
12241235
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper

comfy/text_encoders/flux.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import comfy.text_encoders.sd3_clip
44
import comfy.text_encoders.llama
55
import comfy.model_management
6-
from transformers import T5TokenizerFast, LlamaTokenizerFast
6+
from transformers import T5TokenizerFast, LlamaTokenizerFast, Qwen2Tokenizer
77
import torch
88
import os
99
import json
@@ -172,3 +172,60 @@ def __init__(self, device="cpu", dtype=None, model_options={}):
172172
model_options["num_layers"] = 30
173173
super().__init__(device=device, dtype=dtype, model_options=model_options)
174174
return Flux2TEModel_
175+
176+
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
177+
def __init__(self, embedding_directory=None, tokenizer_data={}):
178+
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
179+
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
180+
181+
class Qwen3Tokenizer8B(sd1_clip.SDTokenizer):
182+
def __init__(self, embedding_directory=None, tokenizer_data={}):
183+
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
184+
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
185+
186+
class KleinTokenizer(sd1_clip.SD1Tokenizer):
187+
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"):
188+
if name == "qwen3_4b":
189+
tokenizer = Qwen3Tokenizer
190+
elif name == "qwen3_8b":
191+
tokenizer = Qwen3Tokenizer8B
192+
193+
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name, tokenizer=tokenizer)
194+
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
195+
196+
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
197+
if llama_template is None:
198+
llama_text = self.llama_template.format(text)
199+
else:
200+
llama_text = llama_template.format(text)
201+
202+
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
203+
return tokens
204+
205+
class KleinTokenizer8B(KleinTokenizer):
206+
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_8b"):
207+
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name)
208+
209+
class Qwen3_4BModel(sd1_clip.SDClipModel):
210+
def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
211+
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
212+
213+
class Qwen3_8BModel(sd1_clip.SDClipModel):
214+
def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
215+
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_8B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
216+
217+
def klein_te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3_4b"):
218+
if model_type == "qwen3_4b":
219+
model = Qwen3_4BModel
220+
elif model_type == "qwen3_8b":
221+
model = Qwen3_8BModel
222+
223+
class Flux2TEModel_(Flux2TEModel):
224+
def __init__(self, device="cpu", dtype=None, model_options={}):
225+
if llama_quantization_metadata is not None:
226+
model_options = model_options.copy()
227+
model_options["quantization_metadata"] = llama_quantization_metadata
228+
if dtype_llama is not None:
229+
dtype = dtype_llama
230+
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
231+
return Flux2TEModel_

comfy/text_encoders/llama.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,28 @@ class Qwen3_4BConfig:
9999
rope_scale = None
100100
final_norm: bool = True
101101

102+
@dataclass
103+
class Qwen3_8BConfig:
104+
vocab_size: int = 151936
105+
hidden_size: int = 4096
106+
intermediate_size: int = 12288
107+
num_hidden_layers: int = 36
108+
num_attention_heads: int = 32
109+
num_key_value_heads: int = 8
110+
max_position_embeddings: int = 40960
111+
rms_norm_eps: float = 1e-6
112+
rope_theta: float = 1000000.0
113+
transformer_type: str = "llama"
114+
head_dim = 128
115+
rms_norm_add = False
116+
mlp_activation = "silu"
117+
qkv_bias = False
118+
rope_dims = None
119+
q_norm = "gemma3"
120+
k_norm = "gemma3"
121+
rope_scale = None
122+
final_norm: bool = True
123+
102124
@dataclass
103125
class Ovis25_2BConfig:
104126
vocab_size: int = 151936
@@ -628,6 +650,15 @@ def __init__(self, config_dict, dtype, device, operations):
628650
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
629651
self.dtype = dtype
630652

653+
class Qwen3_8B(BaseLlama, torch.nn.Module):
654+
def __init__(self, config_dict, dtype, device, operations):
655+
super().__init__()
656+
config = Qwen3_8BConfig(**config_dict)
657+
self.num_layers = config.num_hidden_layers
658+
659+
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
660+
self.dtype = dtype
661+
631662
class Ovis25_2B(BaseLlama, torch.nn.Module):
632663
def __init__(self, config_dict, dtype, device, operations):
633664
super().__init__()

0 commit comments

Comments
 (0)