@@ -101,13 +101,36 @@ def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={})
101101 super ().__init__ ()
102102 self .dtypes = set ()
103103 self .dtypes .add (dtype )
104+ self .compat_mode = False
104105
105106 self .gemma3_12b = Gemma3_12BModel (device = device , dtype = dtype_llama , model_options = model_options , layer = "all" , layer_idx = None )
106107 self .dtypes .add (dtype_llama )
107108
108109 operations = self .gemma3_12b .operations # TODO
109110 self .text_embedding_projection = operations .Linear (3840 * 49 , 3840 , bias = False , dtype = dtype , device = device )
110111
def enable_compat_mode(self):  # TODO: remove
    """Create the audio and video embeddings connectors and enable compat mode.

    Both connectors receive identical constructor arguments:
    ``split_rope`` and ``double_precision_rope`` set to True, the
    ``operations`` object taken from ``self.gemma3_12b``, and the dtype and
    device of ``self.text_embedding_projection.weight``.

    ``self.compat_mode`` is switched to True only after both connectors
    have been assigned.
    """
    from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector

    ops = self.gemma3_12b.operations
    projection_weight = self.text_embedding_projection.weight

    # One shared argument set, so the two connectors cannot drift apart.
    connector_kwargs = dict(
        split_rope=True,
        double_precision_rope=True,
        dtype=projection_weight.dtype,
        device=projection_weight.device,
        operations=ops,
    )

    # The audio connector is created and assigned before the video one.
    self.audio_embeddings_connector = Embeddings1DConnector(**connector_kwargs)
    self.video_embeddings_connector = Embeddings1DConnector(**connector_kwargs)

    self.compat_mode = True
133+
111134 def set_clip_options (self , options ):
112135 self .execution_device = options .get ("execution_device" , self .execution_device )
113136 self .gemma3_12b .set_clip_options (options )
@@ -129,6 +152,12 @@ def encode_token_weights(self, token_weight_pairs):
129152 out = out .reshape ((out .shape [0 ], out .shape [1 ], - 1 ))
130153 out = self .text_embedding_projection (out )
131154 out = out .float ()
155+
156+ if self .compat_mode :
157+ out_vid = self .video_embeddings_connector (out )[0 ]
158+ out_audio = self .audio_embeddings_connector (out )[0 ]
159+ out = torch .concat ((out_vid , out_audio ), dim = - 1 )
160+
132161 return out .to (out_device ), pooled
133162
134163 def generate (self , tokens , do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed ):
@@ -152,6 +181,16 @@ def load_sd(self, sd):
152181 missing_all .extend ([f"{ prefix } { k } " for k in missing ])
153182 unexpected_all .extend ([f"{ prefix } { k } " for k in unexpected ])
154183
184+ if "model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.2.attn1.to_q.bias" not in sd : # TODO: remove
185+ ww = sd .get ("model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.bias" , None )
186+ if ww is not None :
187+ if ww .shape [0 ] == 3840 :
188+ self .enable_compat_mode ()
189+ sdv = comfy .utils .state_dict_prefix_replace (sd , {"model.diffusion_model.video_embeddings_connector." : "" }, filter_keys = True )
190+ self .video_embeddings_connector .load_state_dict (sdv , strict = False , assign = getattr (self , "can_assign_sd" , False ))
191+ sda = comfy .utils .state_dict_prefix_replace (sd , {"model.diffusion_model.audio_embeddings_connector." : "" }, filter_keys = True )
192+ self .audio_embeddings_connector .load_state_dict (sda , strict = False , assign = getattr (self , "can_assign_sd" , False ))
193+
155194 return (missing_all , unexpected_all )
156195
157196 def memory_estimation_function (self , token_weight_pairs , device = None ):
0 commit comments