speed up and reduce VRAM of QWEN VAE and WAN (less so) #12036
Merged
comfyanonymous merged 3 commits into Comfy-Org:master (Jan 24, 2026)
Conversation
This works around PyTorch's inability to apply causal padding as part of the convolution kernel, avoiding the massive duplication that padding the input otherwise incurs.
The code currently uses F.pad, which makes a full deep copy and is liable to be the VRAM peak. Instead, push the spatial padding back into the conv op and consolidate the temporal padding with the cat that builds the cache.
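The pad/cat consolidation can be sketched in numpy (a toy stand-in for the PR's torch code; the shapes, the cache size, and the variable names are illustrative assumptions):

```python
import numpy as np

# Toy (N, C, T, H, W) activation. At the first conv there is no cache yet,
# so the temporal "cache" is simply zero frames.
x = np.random.default_rng(0).standard_normal((1, 4, 2, 8, 8))

# F.pad-style approach: one call pads the temporal axis up front,
# deep-copying the whole tensor into a new, larger buffer.
padded = np.pad(x, ((0, 0), (0, 0), (2, 0), (0, 0), (0, 0)))

# Consolidated approach: the temporal pad *is* the cat with the cache,
# so no separate padded copy is needed; spatial padding is left to the
# convolution op itself.
cache = np.zeros((1, 4, 2, 8, 8))
merged = np.concatenate([cache, x], axis=2)

assert np.array_equal(padded, merged)
```

In numpy both paths copy, of course; the point is that in torch the cat with the cache has to happen anyway, so folding the zero-padding into it removes one full-tensor copy from the peak.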
The QWEN VAE is the WAN VAE used on a single image. Its convolutions are zero-padded 3D convolutions, which means that on a single frame the VAE is effectively 2D: only the last element of the conv weight along the temporal dimension sees real data. Fast-path this to avoid adding zeros that evaporate in the convolution math but still cost computation.
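The single-frame equivalence can be checked numerically. Below is a minimal numpy demonstration (not the PR's code; channels are dropped and the correlation is written out by hand for clarity) that a causally zero-padded 3D convolution on one frame reduces to a 2D convolution with the last temporal weight slice:

```python
import numpy as np

# Toy causal 3D conv on a single frame (batch and channel dims omitted).
kT, kH, kW = 3, 3, 3
rng = np.random.default_rng(0)
w = rng.standard_normal((kT, kH, kW))
frame = rng.standard_normal((6, 6))  # one H x W frame (T=1)

# Full 3D path: causally zero-pad the time axis to kT frames, then correlate.
vol = np.zeros((kT, 6, 6))
vol[-1] = frame  # kT-1 leading zero frames, the real frame last
H, W = 6 - kH + 1, 6 - kW + 1
out3d = np.zeros((H, W))
for i in range(H):
    for j in range(W):
        out3d[i, j] = np.sum(w * vol[:, i:i+kH, j:j+kW])

# Fast path: the zero frames contribute nothing, so only the last
# temporal weight slice matters -> a plain 2D correlation with w[-1].
out2d = np.zeros((H, W))
for i in range(H):
    for j in range(W):
        out2d[i, j] = np.sum(w[-1] * frame[i:i+kH, j:j+kW])

assert np.allclose(out3d, out2d)
```

The fast path never materializes the zero frames, and the sliced-weight 2D conv does kT times less multiply-accumulate work.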
This is a fast path for the QWEN VAE that takes advantage of the fact that zero-padding a 3D convolution on a single frame is mathematically equivalent to just slicing the weight.
I don't have a good explanation for why it is this much faster. I was expecting VRAM savings and hoping for some speed, but I got less VRAM and a ton of speedup. PyTorch works in mysterious ways.
Example test conditions:
RTX5090
QWEN VAE Encode -> Decode (3840x2160)
Before:
After:
WAN2.2 VAE Encode -> Decode (1920x1088x81f)
Before (31.8GB):
After (30.8GB):
This WAN test point is probably thrashing the PyTorch and cuda-malloc-async allocators to stay under the VRAM ceiling, which is why it shows up more as a speedup than as a VRAM saving.