Skip to content

Commit f266b8d

Browse files
Move LTXAV av embedding connectors to diffusion model. (Comfy-Org#12569)
1 parent b6cb30b commit f266b8d

File tree

3 files changed

+30
-23
lines changed

3 files changed

+30
-23
lines changed

comfy/ldm/lightricks/av_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
LTXVModel,
1010
)
1111
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
12+
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
1213
import comfy.ldm.common_dit
1314

1415
class CompressedTimestep:
@@ -450,6 +451,29 @@ def _init_model_components(self, device, dtype, **kwargs):
450451
operations=self.operations,
451452
)
452453

454+
self.audio_embeddings_connector = Embeddings1DConnector(
455+
split_rope=True,
456+
double_precision_rope=True,
457+
dtype=dtype,
458+
device=device,
459+
operations=self.operations,
460+
)
461+
462+
self.video_embeddings_connector = Embeddings1DConnector(
463+
split_rope=True,
464+
double_precision_rope=True,
465+
dtype=dtype,
466+
device=device,
467+
operations=self.operations,
468+
)
469+
470+
def preprocess_text_embeds(self, context):
    """Run text embeddings through the per-modality connectors.

    Passes *context* through the video and audio embedding connectors and
    concatenates their outputs on the feature axis. When the input already
    has ``caption_channels * 2`` features it is treated as pre-processed
    and returned unchanged.

    Args:
        context: text-embedding tensor; last dimension is the feature axis.

    Returns:
        Tensor whose last dimension is ``caption_channels * 2`` (video
        half followed by audio half).
    """
    doubled_width = self.caption_channels * 2
    # Already connector-processed upstream — nothing to do.
    if context.shape[-1] == doubled_width:
        return context
    video_out = self.video_embeddings_connector(context)[0]
    audio_out = self.audio_embeddings_connector(context)[0]
    # Feature layout matches the early-return case: video half, then audio half.
    return torch.concat((video_out, audio_out), dim=-1)
453477
def _init_transformer_blocks(self, device, dtype, **kwargs):
454478
"""Initialize transformer blocks for LTXAV."""
455479
self.transformer_blocks = nn.ModuleList(

comfy/model_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,10 +988,14 @@ def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
988988
def extra_conds(self, **kwargs):
989989
out = super().extra_conds(**kwargs)
990990
attention_mask = kwargs.get("attention_mask", None)
991+
device = kwargs["device"]
992+
991993
if attention_mask is not None:
992994
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
993995
cross_attn = kwargs.get("cross_attn", None)
994996
if cross_attn is not None:
997+
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
998+
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
995999
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
9961000

9971001
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))

comfy/text_encoders/lt.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from transformers import T5TokenizerFast
44
from .spiece_tokenizer import SPieceTokenizer
55
import comfy.text_encoders.genmo
6-
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
76
import torch
87
import comfy.utils
98
import math
@@ -109,22 +108,6 @@ def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={})
109108
operations = self.gemma3_12b.operations # TODO
110109
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
111110

112-
self.audio_embeddings_connector = Embeddings1DConnector(
113-
split_rope=True,
114-
double_precision_rope=True,
115-
dtype=dtype,
116-
device=device,
117-
operations=operations,
118-
)
119-
120-
self.video_embeddings_connector = Embeddings1DConnector(
121-
split_rope=True,
122-
double_precision_rope=True,
123-
dtype=dtype,
124-
device=device,
125-
operations=operations,
126-
)
127-
128111
def set_clip_options(self, options):
129112
self.execution_device = options.get("execution_device", self.execution_device)
130113
self.gemma3_12b.set_clip_options(options)
@@ -146,10 +129,6 @@ def encode_token_weights(self, token_weight_pairs):
146129
out = out.reshape((out.shape[0], out.shape[1], -1))
147130
out = self.text_embedding_projection(out)
148131
out = out.float()
149-
out_vid = self.video_embeddings_connector(out)[0]
150-
out_audio = self.audio_embeddings_connector(out)[0]
151-
out = torch.concat((out_vid, out_audio), dim=-1)
152-
153132
return out.to(out_device), pooled
154133

155134
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
@@ -159,14 +138,14 @@ def load_sd(self, sd):
159138
if "model.layers.47.self_attn.q_norm.weight" in sd:
160139
return self.gemma3_12b.load_sd(sd)
161140
else:
162-
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
141+
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
163142
if len(sdo) == 0:
164143
sdo = sd
165144

166145
missing_all = []
167146
unexpected_all = []
168147

169-
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
148+
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
170149
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
171150
if component_sd:
172151
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))

0 commit comments

Comments (0)