Temporarily unbreak some LTXAV (LTX 2.0) workflows to give people time to migrate. #12605

comfyanonymous merged 1 commit into master from
Conversation
📝 Walkthrough

This change introduces a compatibility mode pathway to the LTXAVTEModel class, gated by a new optional `compat_mode` flag.

🚥 Pre-merge checks: ✅ 2 passed, ❌ 1 failed (1 warning)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@comfy/text_encoders/lt.py`:
- Around line 154-160: The code forces out to float32 then feeds it to
compat-mode connectors, causing dtype mismatches with connector weights (which
may be bf16/fp16); instead of out = out.float(), cast out to the connectors'
dtype before calling self.video_embeddings_connector and
self.audio_embeddings_connector (e.g., determine the connector dtype from a
connector parameter or weight:
next(self.video_embeddings_connector.parameters()).dtype or
self.video_embeddings_connector.weight.dtype) and do out =
out.to(connector_dtype) (or ensure both connectors use the same dtype) so the
forward passes use the connector's dtype and avoid scalar type errors.
```python
out = out.float()

if self.compat_mode:
    out_vid = self.video_embeddings_connector(out)[0]
    out_audio = self.audio_embeddings_connector(out)[0]
    out = torch.concat((out_vid, out_audio), dim=-1)
```
🧩 Analysis chain
🏁 Script executed:

```shell
cat -n comfy/text_encoders/lt.py | head -180 | tail -80
```

Repository: Comfy-Org/ComfyUI
Length of output: 4318
🏁 Script executed:

```shell
# Search for where the connectors are initialized
rg -n "video_embeddings_connector|audio_embeddings_connector" comfy/text_encoders/lt.py -A 2 -B 2
```

Repository: Comfy-Org/ComfyUI
Length of output: 1895
🏁 Script executed:

```shell
# Check the projection's dtype handling
rg -n "projection|dtype" comfy/text_encoders/lt.py | head -20
```

Repository: Comfy-Org/ComfyUI
Length of output: 1793
🏁 Script executed:

```shell
# Check git history for line 154 (the float() call)
git log -p --follow comfy/text_encoders/lt.py | grep -A 15 -B 15 "out = out.float()" | head -80
```

Repository: Comfy-Org/ComfyUI
Length of output: 43
🌐 Web query:
PyTorch float32 input bf16 fp16 module weights dtype mismatch behavior without autocast
💡 Result:
Without autocast, PyTorch generally does not automatically cast activations to match a module's parameter dtype for "typed" ops like GEMM/linear, matmul, and conv. If your input is float32 but the module weights are float16 or bfloat16, you'll typically get a runtime error such as:

- `RuntimeError: expected scalar type Half but found Float`
- `RuntimeError: expected scalar type Float but found BFloat16`
- `RuntimeError: mat1 and mat2 must have the same dtype` (common for addmm/matmul/linear) [1], [2]

This differs from many elementwise arithmetic ops (add/sub/mul/div), which follow PyTorch type promotion rules and may upcast instead of erroring [3].
What autocast changes
Inside torch.autocast(...), PyTorch will automatically cast inputs for certain eligible ops (e.g., mm) so you can feed float32 tensors and still run the op in float16/bfloat16, producing lower-precision outputs without manual casts [4]. PyTorch also warns that tensors produced inside autocast regions (often float16) can later cause dtype mismatch errors when used outside the region with tensors of different dtype—so you may need to cast outputs back to float32 when leaving autocast [5].
Practical fixes (no autocast)

- Make dtypes match explicitly: `x = x.to(weight.dtype)` (or keep the module in float32).
- Or use `torch.autocast(device_type="cuda", dtype=torch.float16)` (or `torch.bfloat16`) around the forward pass where you want mixed precision [4], [5].
Sources: [1] [2] [3] [4] [5]
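The behavior described above is easy to reproduce. The following is an illustrative sketch (standalone, not part of the PR's code) contrasting the typed-op error with elementwise promotion, and showing the cast-to-weight-dtype fix pattern the review recommends:

```python
import torch

# A linear layer whose weights are bfloat16, as the compat-mode connectors may be.
linear = torch.nn.Linear(4, 4).to(torch.bfloat16)
x = torch.randn(2, 4)  # float32 input

# Typed ops (the matmul inside Linear) do NOT auto-promote: this raises a dtype error.
try:
    linear(x)
except RuntimeError as e:
    print("typed op failed:", e)

# Elementwise ops DO follow type promotion: float32 + bfloat16 -> float32.
y = x + torch.randn(2, 4, dtype=torch.bfloat16)
print(y.dtype)  # torch.float32

# The fix pattern from the review: cast the input to the weights' dtype first.
out = linear(x.to(next(linear.parameters()).dtype))
print(out.dtype)  # torch.bfloat16
```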
Fix dtype mismatch between output and compat-mode connectors.

At line 154, `out` is forced to float32, but the connectors (initialized at lines 117 and 125) use the projection's dtype, which is often bf16 or fp16. Passing float32 tensors to modules with bf16/fp16 weights will raise a runtime error like `expected scalar type Half but found Float`.
Cast to connector dtype before the forward pass, then the result stays in the connector's dtype:
Proposed fix:

```diff
-out = out.float()
-
-if self.compat_mode:
-    out_vid = self.video_embeddings_connector(out)[0]
-    out_audio = self.audio_embeddings_connector(out)[0]
-    out = torch.concat((out_vid, out_audio), dim=-1)
+if self.compat_mode:
+    compat_dtype = next(self.video_embeddings_connector.parameters()).dtype
+    out_compat = out.to(dtype=compat_dtype)
+    out_vid = self.video_embeddings_connector(out_compat)[0]
+    out_audio = self.audio_embeddings_connector(out_compat)[0]
+    out = torch.concat((out_vid, out_audio), dim=-1)
+else:
+    out = out.float()
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if self.compat_mode:
    compat_dtype = next(self.video_embeddings_connector.parameters()).dtype
    out_compat = out.to(dtype=compat_dtype)
    out_vid = self.video_embeddings_connector(out_compat)[0]
    out_audio = self.audio_embeddings_connector(out_compat)[0]
    out = torch.concat((out_vid, out_audio), dim=-1)
else:
    out = out.float()
```
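The suggested pattern above can be exercised end to end with a minimal stand-in module. This sketch is hypothetical: the connector names mirror the review, but the dimensions are invented and plain `nn.Linear` layers replace the real connectors (which, per the diff, return tuples indexed with `[0]`):

```python
import torch
import torch.nn as nn

class CompatSketch(nn.Module):
    """Minimal stand-in for the compat-mode forward path discussed above."""

    def __init__(self, dim=8, compat_mode=True, dtype=torch.bfloat16):
        super().__init__()
        self.compat_mode = compat_mode
        self.video_embeddings_connector = nn.Linear(dim, dim, dtype=dtype)
        self.audio_embeddings_connector = nn.Linear(dim, dim, dtype=dtype)

    def forward(self, out):
        if self.compat_mode:
            # Cast to the connectors' dtype so the linear layers do not
            # raise a scalar-type mismatch on float32 input.
            compat_dtype = next(self.video_embeddings_connector.parameters()).dtype
            out_compat = out.to(dtype=compat_dtype)
            out_vid = self.video_embeddings_connector(out_compat)
            out_audio = self.audio_embeddings_connector(out_compat)
            out = torch.concat((out_vid, out_audio), dim=-1)
        else:
            out = out.float()
        return out

m = CompatSketch()
result = m(torch.randn(2, 8))  # float32 input works despite bf16 weights
print(result.shape, result.dtype)
```

Note that the output stays in the connectors' dtype on the compat path, matching the review's observation that no cast back to float32 is needed there.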
This will eventually be removed again, which will break many workflows that don't use the official LTXAV (LTX 2.0) files.
If you use the official LTXV files you are good. If you use non-official files, please migrate.