Skip to content

Commit 84aba95

Browse files
Temporarily unbreak some LTXAV workflows to give people time to migrate. (Comfy-Org#12605)
1 parent 9b1c63e commit 84aba95

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

comfy/text_encoders/lt.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,36 @@ def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={})
101101
super().__init__()
102102
self.dtypes = set()
103103
self.dtypes.add(dtype)
104+
self.compat_mode = False
104105

105106
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
106107
self.dtypes.add(dtype_llama)
107108

108109
operations = self.gemma3_12b.operations # TODO
109110
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
110111

112+
def enable_compat_mode(self):  # TODO: remove once old LTXAV workflows have migrated
    """Attach legacy audio/video 1D embedding connectors (backward-compat shim).

    Builds two `Embeddings1DConnector` modules on the same dtype/device as the
    text-embedding projection and flips `self.compat_mode` so the encode path
    routes outputs through them.
    """
    from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector

    ops = self.gemma3_12b.operations
    # Match the projection layer's weight so the connectors live on the
    # same device with the same dtype.
    proj_weight = self.text_embedding_projection.weight
    connector_kwargs = dict(
        split_rope=True,
        double_precision_rope=True,
        dtype=proj_weight.dtype,
        device=proj_weight.device,
        operations=ops,
    )

    self.audio_embeddings_connector = Embeddings1DConnector(**connector_kwargs)
    self.video_embeddings_connector = Embeddings1DConnector(**connector_kwargs)
    self.compat_mode = True
133+
111134
def set_clip_options(self, options):
112135
self.execution_device = options.get("execution_device", self.execution_device)
113136
self.gemma3_12b.set_clip_options(options)
@@ -129,6 +152,12 @@ def encode_token_weights(self, token_weight_pairs):
129152
out = out.reshape((out.shape[0], out.shape[1], -1))
130153
out = self.text_embedding_projection(out)
131154
out = out.float()
155+
156+
if self.compat_mode:
157+
out_vid = self.video_embeddings_connector(out)[0]
158+
out_audio = self.audio_embeddings_connector(out)[0]
159+
out = torch.concat((out_vid, out_audio), dim=-1)
160+
132161
return out.to(out_device), pooled
133162

134163
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
@@ -152,6 +181,16 @@ def load_sd(self, sd):
152181
missing_all.extend([f"{prefix}{k}" for k in missing])
153182
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
154183

184+
if "model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.2.attn1.to_q.bias" not in sd: # TODO: remove
185+
ww = sd.get("model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.bias", None)
186+
if ww is not None:
187+
if ww.shape[0] == 3840:
188+
self.enable_compat_mode()
189+
sdv = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.video_embeddings_connector.": ""}, filter_keys=True)
190+
self.video_embeddings_connector.load_state_dict(sdv, strict=False, assign=getattr(self, "can_assign_sd", False))
191+
sda = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.audio_embeddings_connector.": ""}, filter_keys=True)
192+
self.audio_embeddings_connector.load_state_dict(sda, strict=False, assign=getattr(self, "can_assign_sd", False))
193+
155194
return (missing_all, unexpected_all)
156195

157196
def memory_estimation_function(self, token_weight_pairs, device=None):

0 commit comments

Comments (0)