33from transformers import T5TokenizerFast
44from .spiece_tokenizer import SPieceTokenizer
55import comfy .text_encoders .genmo
6- from comfy .ldm .lightricks .embeddings_connector import Embeddings1DConnector
76import torch
87import comfy .utils
98import math
@@ -109,22 +108,6 @@ def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={})
109108 operations = self .gemma3_12b .operations # TODO
110109 self .text_embedding_projection = operations .Linear (3840 * 49 , 3840 , bias = False , dtype = dtype , device = device )
111110
112- self .audio_embeddings_connector = Embeddings1DConnector (
113- split_rope = True ,
114- double_precision_rope = True ,
115- dtype = dtype ,
116- device = device ,
117- operations = operations ,
118- )
119-
120- self .video_embeddings_connector = Embeddings1DConnector (
121- split_rope = True ,
122- double_precision_rope = True ,
123- dtype = dtype ,
124- device = device ,
125- operations = operations ,
126- )
127-
128111 def set_clip_options (self , options ):
129112 self .execution_device = options .get ("execution_device" , self .execution_device )
130113 self .gemma3_12b .set_clip_options (options )
@@ -146,10 +129,6 @@ def encode_token_weights(self, token_weight_pairs):
146129 out = out .reshape ((out .shape [0 ], out .shape [1 ], - 1 ))
147130 out = self .text_embedding_projection (out )
148131 out = out .float ()
149- out_vid = self .video_embeddings_connector (out )[0 ]
150- out_audio = self .audio_embeddings_connector (out )[0 ]
151- out = torch .concat ((out_vid , out_audio ), dim = - 1 )
152-
153132 return out .to (out_device ), pooled
154133
155134 def generate (self , tokens , do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed ):
@@ -159,14 +138,14 @@ def load_sd(self, sd):
159138 if "model.layers.47.self_attn.q_norm.weight" in sd :
160139 return self .gemma3_12b .load_sd (sd )
161140 else :
162- sdo = comfy .utils .state_dict_prefix_replace (sd , {"text_embedding_projection.aggregate_embed.weight" : "text_embedding_projection.weight" , "model.diffusion_model.video_embeddings_connector." : "video_embeddings_connector." , "model.diffusion_model.audio_embeddings_connector." : "audio_embeddings_connector." }, filter_keys = True )
141+ sdo = comfy .utils .state_dict_prefix_replace (sd , {"text_embedding_projection.aggregate_embed.weight" : "text_embedding_projection.weight" }, filter_keys = True )
163142 if len (sdo ) == 0 :
164143 sdo = sd
165144
166145 missing_all = []
167146 unexpected_all = []
168147
169- for prefix , component in [("text_embedding_projection." , self .text_embedding_projection ), ( "video_embeddings_connector." , self . video_embeddings_connector ), ( "audio_embeddings_connector." , self . audio_embeddings_connector ) ]:
148+ for prefix , component in [("text_embedding_projection." , self .text_embedding_projection )]:
170149 component_sd = {k .replace (prefix , "" ): v for k , v in sdo .items () if k .startswith (prefix )}
171150 if component_sd :
172151 missing , unexpected = component .load_state_dict (component_sd , strict = False , assign = getattr (self , "can_assign_sd" , False ))
0 commit comments