
Commit 98b25d4

Implement Jina CLIP v2
1 parent 16d85ea commit 98b25d4

6 files changed: +296 additions, −4 deletions

comfy/model_base.py

Lines changed: 1 addition & 1 deletion

@@ -1110,7 +1110,7 @@ def extra_conds(self, **kwargs):
         if 'num_tokens' not in out:
             out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
 
-        clip_text_pooled = kwargs["pooled_output"] # Newbie
+        clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
         if clip_text_pooled is not None:
            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
 

comfy/model_detection.py

Lines changed: 2 additions & 1 deletion

@@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["rope_theta"] = 10000.0
         dit_config["ffn_dim_multiplier"] = 4.0
         ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
-        if ctd_weight is not None:
+        if ctd_weight is not None: # NewBie
             dit_config["clip_text_dim"] = ctd_weight.shape[0]
+            # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
     elif dit_config["dim"] == 3840: # Z image
         dit_config["n_heads"] = 30
         dit_config["n_kv_heads"] = 30

comfy/sd.py

Lines changed: 20 additions & 0 deletions

@@ -55,6 +55,8 @@
 import comfy.text_encoders.z_image
 import comfy.text_encoders.ovis
 import comfy.text_encoders.kandinsky5
+import comfy.text_encoders.jina_clip_2
+import comfy.text_encoders.newbie
 
 import comfy.model_patcher
 import comfy.lora

@@ -992,6 +994,7 @@ class CLIPType(Enum):
     OVIS = 21
     KANDINSKY5 = 22
     KANDINSKY5_IMAGE = 23
+    NEWBIE = 24
 
 
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):

@@ -1022,6 +1025,7 @@ class TEModel(Enum):
     MISTRAL3_24B_PRUNED_FLUX2 = 15
     QWEN3_4B = 16
     QWEN3_2B = 17
+    JINA_CLIP_2 = 18
 
 
 def detect_te_model(sd):

@@ -1031,6 +1035,8 @@ def detect_te_model(sd):
         return TEModel.CLIP_H
     if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
         return TEModel.CLIP_L
+    if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
+        return TEModel.JINA_CLIP_2
     if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
         weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
         if weight.shape[-1] == 4096:

@@ -1191,6 +1197,9 @@ class EmptyClass:
         elif te_model == TEModel.QWEN3_2B:
             clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
+        elif te_model == TEModel.JINA_CLIP_2:
+            clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
+            clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
         else:
             # clip_l
             if clip_type == CLIPType.SD3:

@@ -1246,6 +1255,17 @@ class EmptyClass:
         elif clip_type == CLIPType.KANDINSKY5_IMAGE:
             clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
+        elif clip_type == CLIPType.NEWBIE:
+            clip_target.clip = comfy.text_encoders.newbie.NewBieClipModel
+            clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
+            if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
+                clip_data_gemma = clip_data[0]
+                clip_data_jina = clip_data[1]
+            else:
+                clip_data_gemma = clip_data[1]
+                clip_data_jina = clip_data[0]
+            tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
+            tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
             clip_target.tokenizer = sdxl_clip.SDXLTokenizer
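For reference, a minimal sketch of how the new NEWBIE type might be exercised through load_clip. The checkpoint filenames are hypothetical; the branch above decides which checkpoint is Gemma and which is Jina by looking for a Gemma-specific q_norm key, so the order of the two paths should not matter.

# Sketch only; filenames are hypothetical stand-ins for the two text-encoder checkpoints.
import comfy.sd

clip = comfy.sd.load_clip(
    ckpt_paths=["gemma_3_4b_it.safetensors", "jina_clip_v2_text.safetensors"],
    clip_type=comfy.sd.CLIPType.NEWBIE,
)
tokens = clip.tokenize("a watercolor fox")  # dict with "gemma" and "jina" entries (NewBieTokenizer)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
# cond comes from the Gemma hidden states, pooled from Jina CLIP v2's mean-pooled embedding.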

comfy/text_encoders/jina_clip_2.py

Lines changed: 219 additions & 0 deletions

@@ -0,0 +1,219 @@
+# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
+# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
+# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
+
+from dataclasses import dataclass
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+import comfy.model_management
+import comfy.ops
+from comfy import sd1_clip
+from .spiece_tokenizer import SPieceTokenizer
+
+class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer = tokenizer_data.get("spiece_model", None)
+        # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
+
+    def state_dict(self):
+        return {"spiece_model": self.tokenizer.serialize_model()}
+
+class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
+
+# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
+@dataclass
+class XLMRobertaConfig:
+    vocab_size: int = 250002
+    type_vocab_size: int = 1
+    hidden_size: int = 1024
+    num_hidden_layers: int = 24
+    num_attention_heads: int = 16
+    rotary_emb_base: float = 20000.0
+    intermediate_size: int = 4096
+    hidden_act: str = "gelu"
+    hidden_dropout_prob: float = 0.1
+    attention_probs_dropout_prob: float = 0.1
+    layer_norm_eps: float = 1e-05
+    bos_token_id: int = 0
+    eos_token_id: int = 2
+    pad_token_id: int = 1
+
+class XLMRobertaEmbeddings(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
+        self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
+
+    def forward(self, input_ids=None, embeddings=None):
+        if input_ids is not None and embeddings is None:
+            embeddings = self.word_embeddings(input_ids)
+
+        if embeddings is not None:
+            token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            embeddings = embeddings + token_type_embeddings
+        return embeddings
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim, base, device=None):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached = None
+        self._sin_cached = None
+
+    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+        if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=device, dtype=torch.float32)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            emb = torch.cat((freqs, freqs), dim=-1)
+            self._cos_cached = emb.cos().to(dtype)
+            self._sin_cached = emb.sin().to(dtype)
+
+    def forward(self, q, k):
+        batch, seqlen, heads, head_dim = q.shape
+        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
+
+        cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
+        sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
+
+        def rotate_half(x):
+            size = x.shape[-1] // 2
+            x1, x2 = x[..., :size], x[..., size:]
+            return torch.cat((-x2, x1), dim=-1)
+
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+class MHA(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = embed_dim // config.num_attention_heads
+
+        self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
+        self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
+        self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
+
+    def forward(self, x, mask=None, optimized_attention=None):
+        qkv = self.Wqkv(x)
+        batch_size, seq_len, _ = qkv.shape
+        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
+        q, k, v = qkv.unbind(2)
+
+        q, k = self.rotary_emb(q, k)
+
+        # NHD -> HND
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
+        return self.out_proj(out)
+
+class MLP(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
+        self.activation = F.gelu
+        self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.activation(x)
+        x = self.fc2(x)
+        return x
+
+class Block(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
+        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
+        self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
+        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
+        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
+        self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
+
+    def forward(self, hidden_states, mask=None, optimized_attention=None):
+        mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
+        hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
+        mlp_out = self.mlp(hidden_states)
+        hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
+        return hidden_states
+
+class XLMRobertaEncoder(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
+
+    def forward(self, hidden_states, attention_mask=None):
+        optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
+        return hidden_states
+
+class XLMRobertaModel_(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
+        self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
+        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
+        self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
+
+    def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
+        x = self.embeddings(input_ids=input_ids, embeddings=embeds)
+        x = self.emb_ln(x)
+        x = self.emb_drop(x)
+
+        mask = None
+        if attention_mask is not None:
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
+            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
+
+        sequence_output = self.encoder(x, attention_mask=mask)
+
+        # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
+        pooled_output = None
+        if attention_mask is None:
+            pooled_output = sequence_output.mean(dim=1)
+        else:
+            attention_mask = attention_mask.to(sequence_output.dtype)
+            pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
+
+        # Intermediate output is not yet implemented, use None for placeholder
+        return sequence_output, None, pooled_output
+
+class XLMRobertaModel(nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        self.config = XLMRobertaConfig(**config_dict)
+        self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
+        self.num_layers = self.config.num_hidden_layers
+
+    def get_input_embeddings(self):
+        return self.model.embeddings.word_embeddings
+
+    def set_input_embeddings(self, embeddings):
+        self.model.embeddings.word_embeddings = embeddings
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+class JinaClip2TextModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
+
+class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
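As a side note, the RotaryEmbedding/rotate_half combination in this file is the standard rotate-half formulation of RoPE. The following self-contained check (illustrative only, not part of the commit) verifies that the cos/sin update applied to q and k equals a complex rotation of each (first-half, second-half) feature pair by the position-dependent angle.

# Standalone sanity check of the rotate-half RoPE identity used in RotaryEmbedding.forward (illustrative only).
import torch

def rotate_half(x):
    size = x.shape[-1] // 2
    x1, x2 = x[..., :size], x[..., size:]
    return torch.cat((-x2, x1), dim=-1)

head_dim, base, seqlen = 64, 20000.0, 8
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)  # (seqlen, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()                                           # (seqlen, head_dim)

q = torch.randn(seqlen, head_dim)
q_rope = q * cos + rotate_half(q) * sin  # what the module applies per position and head

# The same rotation written as complex multiplication of (x1, x2) pairs by e^{i * theta}.
q_pairs = torch.complex(q[..., :head_dim // 2], q[..., head_dim // 2:])
q_ref = q_pairs * torch.polar(torch.ones_like(freqs), freqs)
q_ref = torch.cat((q_ref.real, q_ref.imag), dim=-1)

assert torch.allclose(q_rope, q_ref, atol=1e-5)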

comfy/text_encoders/newbie.py

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+import torch
+
+import comfy.model_management
+import comfy.text_encoders.jina_clip_2
+import comfy.text_encoders.lumina2
+
+class NewBieTokenizer:
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
+        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
+
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+        out = {}
+        out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
+        return out
+
+    def untokenize(self, token_weight_pair):
+        raise NotImplementedError
+
+    def state_dict(self):
+        return {}
+
+class NewBieClipModel(torch.nn.Module):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__()
+        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype, model_options=model_options)
+        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
+        self.dtypes = {dtype}
+
+    def set_clip_options(self, options):
+        self.gemma.set_clip_options(options)
+        self.jina.set_clip_options(options)
+
+    def reset_clip_options(self):
+        self.gemma.reset_clip_options()
+        self.jina.reset_clip_options()
+
+    def encode_token_weights(self, token_weight_pairs):
+        token_weight_pairs_gemma = token_weight_pairs["gemma"]
+        token_weight_pairs_jina = token_weight_pairs["jina"]
+
+        gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
+        jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
+
+        return gemma_out, jina_pooled, gemma_extra
+
+    def load_sd(self, sd):
+        if "model.layers.0.self_attn.q_norm.weight" in sd:
+            return self.gemma.load_sd(sd)
+        else:
+            return self.jina.load_sd(sd)
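The jina_pooled value returned by encode_token_weights is the masked mean pool computed in XLMRobertaModel_ above; it becomes pooled_output and, via the model_base.py change, clip_text_pooled. A toy illustration of that pooling (not part of the commit):

# Toy check of the masked mean pooling that produces the Jina pooled output (illustrative only).
import torch

sequence_output = torch.randn(1, 5, 1024)                    # (batch, seq_len, hidden_size)
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])   # last two positions are padding

pooled = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)

# Equivalent to averaging only the unmasked token embeddings.
assert torch.allclose(pooled, sequence_output[:, :3].mean(dim=1), atol=1e-6)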

nodes.py

Lines changed: 2 additions & 2 deletions

@@ -970,7 +970,7 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                               "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ),
                              },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),

@@ -980,7 +980,7 @@ def INPUT_TYPES(s):
 
     CATEGORY = "advanced/loaders"
 
-    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
+    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
 
     def load_clip(self, clip_name1, clip_name2, type, device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
