
Conversation

imperatormk commented Dec 13, 2025

Problem

Multi-GPU inference crashes when sample_steps > num_gpus_dit:

  • --sample_steps 8 with --num_gpus_dit 4 causes KeyError: 'cond_shape' or KeyError: '5'

Root cause: The denoising loop maps step i directly to rank i (if i != dist.get_rank(): continue), so steps 4+ have no GPU to run them.
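
For illustration, a minimal sketch of the original dispatch (simplified, not the repository's exact code) shows why higher step indices are silently dropped:

    # Original 1:1 dispatch: step i only runs on rank i. With 4 DiT ranks
    # (0-3), steps 4..7 match no rank, so nothing computes them and later
    # lookups fail with the KeyErrors above.
    for i, t in enumerate(timesteps):        # e.g. sample_steps = 8
        if i != dist.get_rank():             # ranks only reach num_gpus_dit - 1
            continue
        ...                                  # DiT forward + scheduler step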

Solution

Implement modulo cycling so steps wrap around available GPUs:

  • Steps 0,4,8,12 → Rank 0
  • Steps 1,5,9,13 → Rank 1
  • Steps 2,6,10,14 → Rank 2
  • Steps 3,7,11,15 → Rank 3

Changes:

  1. Modulo cycling: if i % num_gpus_dit != my_rank instead of if i != dist.get_rank() (see the sketch after this list)
  2. Dynamic send/recv: Last rank in cycle sends to rank 0 (not VAE), final step sends to VAE
  3. Scheduler fix: Set sample_scheduler._step_index = i for correct timestep handling
  4. kv_cache init: Initialize for all ranks including VAE to avoid KeyError
  5. KV cache sync: Broadcast KV cache from the GPU that completed each step to all other DiT GPUs
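
A minimal sketch of how changes 1-3 fit together, assuming num_gpus_dit DiT ranks plus one VAE rank; denoise_step, vae_rank, timesteps and latents are placeholder names, not the repository's actual identifiers:

    import torch.distributed as dist

    my_rank = dist.get_rank()
    vae_rank = num_gpus_dit                        # assumed layout: VAE sits on the last rank

    for i, t in enumerate(timesteps):
        if i % num_gpus_dit != my_rank:            # change 1: steps wrap around the DiT GPUs
            continue
        if i > 0:                                  # receive the latents from the previous step's rank
            dist.recv(latents, src=(i - 1) % num_gpus_dit)
        sample_scheduler._step_index = i           # change 3: keep the scheduler's timestep index aligned
        latents = denoise_step(latents, t)         # placeholder for the DiT forward + scheduler step
        if i == len(timesteps) - 1:
            dist.send(latents, dst=vae_rank)       # change 2: final step hands off to the VAE rank
        else:                                      # change 2: otherwise pass to the next rank in the cycle
            dist.send(latents, dst=(i + 1) % num_gpus_dit)

The KV-cache sync described in the next section would slot in right after the denoising call.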

Why KV cache sync is needed

With modulo cycling, each GPU maintains its own KV cache. The denoising loop processes the same temporal block across all steps, with each step overwriting the same cache positions:

    Step   GPU     GPU's KV cache
    0      GPU 0   writes [X:X+len]
    1      GPU 1   overwrites [X:X+len]
    2      GPU 2   overwrites [X:X+len]
    3      GPU 3   overwrites [X:X+len]
    4      GPU 0   has stale values from step 0!

Without sync, GPU 0 at step 4 reads stale KV values from step 0 instead of step 3's values. This breaks temporal attention alignment, causing lipsync drift.

The fix broadcasts the updated KV cache after each step: the rank that just ran the step sends it to every other DiT rank via dist.send/recv.
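
A hedged sketch of that synchronization, assuming kv_cache is a flat dict of tensors with the same keys on every rank (the real structure may differ) and that the VAE occupies the last rank:

    import torch.distributed as dist

    def _sync_kv_cache(kv_cache, src_rank, num_gpus_dit):
        """Push the KV cache updated by src_rank to every other DiT rank so the
        rank that owns the next step starts from fresh values."""
        my_rank = dist.get_rank()
        if my_rank >= num_gpus_dit:                # the VAE rank does not take part
            return
        for tensor in kv_cache.values():           # identical iteration order on every rank
            if my_rank == src_rank:
                for dst in range(num_gpus_dit):
                    if dst != src_rank:
                        dist.send(tensor, dst=dst)
            else:
                dist.recv(tensor, src=src_rank)    # overwrite the stale entries in place

Transferring the full cache every step is the overhead concern raised in the review below; putting the DiT ranks in a dedicated process group and calling dist.broadcast would express the same sync more compactly.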

Testing

Tested on 5 GPUs (4 DiT + 1 VAE) with:

  • --sample_steps 8
  • --sample_steps 16

Video quality improved with more denoising steps as expected.

Previously, multi-GPU inference crashed with KeyError when sample_steps
exceeded num_gpus_dit (e.g., 8 steps with 4 DiT GPUs).

The issue was a 1:1 mapping of step i to rank i, leaving steps 4+ with no GPU.

Fix:
- Modulo cycling: step i runs on rank (i % num_gpus_dit)
- Dynamic send/recv: cycle wraps to rank 0, final step sends to VAE
- Scheduler index fix: set _step_index = i for correct timestep

Tested with 8 and 16 steps on 4 DiT GPUs + 1 VAE GPU.
imperatormk force-pushed the fix/modulo-cycling-multi-gpu branch from f62c1e7 to 5714523 on December 13, 2025 at 23:44
Fixes lipsync desync when using sample_steps > num_gpus_dit.

With modulo cycling, each GPU maintains its own KV cache. The denoising
loop processes the same temporal block across all steps, with each step
overwriting the same cache positions. When GPU 0 processes step 4, it
still has stale KV values from step 0 instead of step 3's values.

This commit adds _sync_kv_cache() which broadcasts the updated cache
from the GPU that just completed a step to all other DiT GPUs.
Yubo-Shankui (Collaborator) commented:

Thanks a lot for the detailed investigation and the fix — enabling sample_steps > num_gpus_dit is definitely a meaningful and useful improvement.

I tried merging your changes locally, but in my setup multi-GPU inference consistently hangs at the communication between step 1 → step 2 (progress stalls after finishing step 1). I’m still debugging whether this is a corner case in the current send/recv logic or something specific to my environment.

One additional concern I have is that synchronizing the full KV cache across all DiT GPUs at every denoising step could introduce substantial communication overhead, which may significantly reduce multi-GPU parallel efficiency. While correctness-wise the fix makes sense, the performance trade-off might be non-trivial.

That said, the flexibility to decouple sample_steps from num_gpus_dit is very valuable. An alternative direction could be to reserve multiple KV caches per GPU, or to adapt the pipeline to more GPUs (e.g., 8+1 for 8-step denoising) to avoid frequent cross-GPU KV synchronization.

Thanks again for the contribution — I think this is an important direction to explore further.
