|
# Jina CLIP v2 and Jina Embeddings v3 both use Jina's modified XLM-RoBERTa architecture. Reference implementations:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): https://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py

from dataclasses import dataclass

import torch
from torch import nn as nn
from torch.nn import functional as F

import comfy.ldm.modules.attention
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer

class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        # The official NewBie implementation uses max_length=8000, but Jina Embeddings v3 actually supports 8192
        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}

class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")

# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
    vocab_size: int = 250002
    type_vocab_size: int = 1
    hidden_size: int = 1024
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    rotary_emb_base: float = 20000.0
    intermediate_size: int = 4096
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-05
    bos_token_id: int = 0
    eos_token_id: int = 2
    pad_token_id: int = 1

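# Note, derived from the config values above: hidden_size=1024 with num_attention_heads=16
# gives a per-head dimension of 1024 // 16 = 64, which is the rotary dimension used in MHA
# below. type_vocab_size=1 means the token type embedding in XLMRobertaEmbeddings reduces to
# a single learned vector that is added to every position.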
class XLMRobertaEmbeddings(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        embed_dim = config.hidden_size
        self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
        self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)

    def forward(self, input_ids=None, embeddings=None):
        if input_ids is not None and embeddings is None:
            embeddings = self.word_embeddings(input_ids)

        if embeddings is not None:
            token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            emb = torch.cat((freqs, freqs), dim=-1)
            self._cos_cached = emb.cos().to(dtype)
            self._sin_cached = emb.sin().to(dtype)

    def forward(self, q, k):
        batch, seqlen, heads, head_dim = q.shape
        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)

        cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
        sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)

        def rotate_half(x):
            size = x.shape[-1] // 2
            x1, x2 = x[..., :size], x[..., size:]
            return torch.cat((-x2, x1), dim=-1)

        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

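# A minimal sanity-check sketch for RotaryEmbedding above (illustrative only, not executed
# here): cos/sin are duplicated across the two halves of the head dimension, so every
# (x_i, x_{i + dim/2}) pair is rotated by the same position-dependent angle and vector norms
# are preserved.
#
#     rope = RotaryEmbedding(dim=64, base=20000.0)
#     q = k = torch.randn(2, 8, 16, 64)  # (batch, seqlen, heads, head_dim)
#     q_rot, k_rot = rope(q, k)
#     assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)
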
class MHA(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = embed_dim // config.num_attention_heads

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
        self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
        self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, x, mask=None, optimized_attention=None):
        qkv = self.Wqkv(x)
        batch_size, seq_len, _ = qkv.shape
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)

        q, k = self.rotary_emb(q, k)

        # NHD -> HND
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
        return self.out_proj(out)
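    # Shape walk-through for forward() with the default config (hidden_size=1024, 16 heads,
    # head_dim=64), descriptive only:
    #   x:   (B, S, 1024)
    #   qkv: (B, S, 3072) -> view -> (B, S, 3, 16, 64) -> unbind -> q, k, v each (B, S, 16, 64)
    #   rotary + transpose(1, 2): q, k, v each (B, 16, S, 64)
    #   optimized_attention(..., skip_reshape=True) merges the heads back into (B, S, 1024),
    #   which out_proj maps to (B, S, 1024).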

class MLP(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
        self.activation = F.gelu
        self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x

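# The transformer block below follows the original BERT/XLM-RoBERTa post-norm layout:
# attention and MLP outputs are added to the residual first, then LayerNorm is applied to
# the sum, i.e. x = LN(x + Dropout(Attn(x))) followed by x = LN(x + Dropout(MLP(x))).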
class Block(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
        self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)

    def forward(self, hidden_states, mask=None, optimized_attention=None):
        mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
        hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
        mlp_out = self.mlp(hidden_states)
        hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
        return hidden_states

class XLMRobertaEncoder(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
        for layer in self.layers:
            hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
        return hidden_states

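# Input/output sketch for XLMRobertaModel_ below (shapes assume the default config):
# input_ids is (B, S) and attention_mask is (B, S) with 1 for real tokens and 0 for padding.
# The 0/1 mask is converted into an additive bias (0 where tokens are kept, -finfo.max where
# padded) broadcast over heads. The forward pass returns sequence_output of shape (B, S, 1024),
# None as a placeholder for the intermediate output, and a masked mean pool over real tokens,
# pooled = sum_t(h_t * m_t) / sum_t(m_t), of shape (B, 1024).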
class XLMRobertaModel_(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
        self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)

    def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
        x = self.embeddings(input_ids=input_ids, embeddings=embeds)
        x = self.emb_ln(x)
        x = self.emb_drop(x)

        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        sequence_output = self.encoder(x, attention_mask=mask)

        # Mean pooling, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
        pooled_output = None
        if attention_mask is None:
            pooled_output = sequence_output.mean(dim=1)
        else:
            attention_mask = attention_mask.to(sequence_output.dtype)
            pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)

        # Intermediate output is not yet implemented; return None as a placeholder
        return sequence_output, None, pooled_output

class XLMRobertaModel(nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.config = XLMRobertaConfig(**config_dict)
        self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
        self.num_layers = self.config.num_hidden_layers

    def get_input_embeddings(self):
        return self.model.embeddings.word_embeddings

    def set_input_embeddings(self, embeddings):
        self.model.embeddings.word_embeddings = embeddings

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

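# The special token ids passed to SDClipModel below (start=0, end=2, pad=1) match the
# bos/eos/pad ids in XLMRobertaConfig above and the XLM-RoBERTa SentencePiece vocabulary
# (<s> = 0, <pad> = 1, </s> = 2).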
class JinaClip2TextModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)

class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)