
speed up and reduce VRAM of QWEN VAE and WAN (less so)#12036

Merged
comfyanonymous merged 3 commits into Comfy-Org:master from rattus128:prs/qwen-vae-2d on Jan 24, 2026
Conversation

@rattus128
Contributor

This is a fast path for the QWEN VAE that takes advantage of the fact that a zero-padded 3D convolution applied to a single frame is mathematically equivalent to a 2D convolution with a slice of the weight.

I don't have a good explanation for why it is this much faster. I was expecting VRAM savings and hoping for some speed, but I got less VRAM and a ton of speedup. PyTorch works in mysterious ways.
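The equivalence itself is easy to check. The sketch below is my own illustration (not code from this PR), assuming a causally zero-padded 3x3x3 convolution: with a single input frame, every temporal tap of the kernel except the last multiplies zeros, so the 3D conv collapses to a 2D conv using only the last temporal slice of the weight.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Single-frame input: (N, C_in, T=1, H, W)
x = torch.randn(1, 4, 1, 8, 8)

# 3D conv weight with temporal kernel size 3: (C_out, C_in, kT, kH, kW)
w = torch.randn(6, 4, 3, 3, 3)
b = torch.randn(6)

# Causal temporal padding: kT - 1 = 2 zero frames in front of the clip,
# plus ordinary spatial padding of 1 on each side.
# F.pad pad order is (W_left, W_right, H_top, H_bottom, T_front, T_back).
x_pad = F.pad(x, (1, 1, 1, 1, 2, 0))
y_3d = F.conv3d(x_pad, w, b)

# With T=1 and only zeros in front, the first two temporal weight slices
# only ever multiply zeros, so slicing the LAST temporal weight slice and
# running a plain 2D conv gives the same result.
y_2d = F.conv2d(x[:, :, 0], w[:, :, -1], b, padding=1)

assert torch.allclose(y_3d[:, :, 0], y_2d, atol=1e-5)
```

The sliced 2D path never materializes the padded zero frames, which is where the VRAM (and apparently a lot of the compute) goes.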

Example test conditions:

RTX5090

QWEN VAE Encode -> Decode (3840x2160)

Before:

Requested to load WanVAE
loaded completely; 7324.16 MB usable, 242.03 MB loaded, full load: True
0 models unloaded.
Unloaded partially: 242.03 MB freed, 0.00 MB remains loaded, 22.78 MB buffer reserved, lowvram patches: 0
Prompt executed in 36.82 seconds

After:

Requested to load WanVAE
loaded completely; 7324.16 MB usable, 242.03 MB loaded, full load: True
0 models unloaded.
Unloaded partially: 242.03 MB freed, 0.00 MB remains loaded, 22.78 MB buffer reserved, lowvram patches: 0
Prompt executed in 2.42 seconds

WAN2.2 VAE Encode -> Decode (1920x1088x81f)

Before (31.8 GB):

Requested to load WanVAE
loaded completely; 7148.38 MB usable, 242.03 MB loaded, full load: True
Prompt executed in 51.00 seconds

After (30.8 GB):

Requested to load WanVAE
loaded completely; 7148.38 MB usable, 242.03 MB loaded, full load: True
Prompt executed in 41.56 seconds

This WAN test point is probably thrashing the PyTorch and cuda-malloc-async allocators to stay under the VRAM ceiling, which is why it shows up more as a speedup than as a VRAM saving.

@rattus128 rattus128 marked this pull request as draft January 23, 2026 02:47
This works around PyTorch's missing ability to causal-pad as part of the
kernel and avoids massive weight duplications for padding.

This currently uses F.pad, which takes a full deep copy and is liable to
be the VRAM peak. Instead, kick spatial padding back to the op and
consolidate the temporal padding with the cat for the cache.

The WAN VAE is also the QWEN VAE, where it is used single-image. These
convolutions are, however, zero-padded 3D convolutions, which means the
VAE is actually just 2D down the last element of the conv weight in
the temporal dimension. Fast-path this to avoid adding zeros that
then just evaporate in convolution math but cost computation.
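The F.pad point in the commit messages above can be sketched as follows. This is my own illustration with made-up shapes, not the PR's actual code: spatially padding via F.pad materializes a full padded copy of the activation, while passing padding= to conv3d lets the op pad internally, and the temporal "padding" frames can ride along in the torch.cat that the frame cache already requires.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(1, 4, 5, 8, 8)       # current clip: (N, C, T, H, W)
cache = torch.randn(1, 4, 2, 8, 8)   # previous frames kept for the causal conv
w = torch.randn(6, 4, 3, 3, 3)

# Naive path: F.pad builds a full padded copy of the concatenated tensor,
# which is liable to be the VRAM peak.
x_pad = F.pad(torch.cat([cache, x], dim=2), (1, 1, 1, 1, 0, 0))
y_a = F.conv3d(x_pad, w)

# Leaner path: let conv3d apply the spatial padding itself and fold the
# temporal padding into the cat with the cache (no extra padded copy).
y_b = F.conv3d(torch.cat([cache, x], dim=2), w, padding=(0, 1, 1))

assert torch.allclose(y_a, y_b, atol=1e-5)
```

Both paths compute the same output; the second just skips one full-size intermediate tensor.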
@rattus128 rattus128 marked this pull request as ready for review January 23, 2026 07:41
@comfyanonymous comfyanonymous merged commit 4e6a1b6 into Comfy-Org:master Jan 24, 2026
12 checks passed
