fix: enable sample_steps > num_gpus_dit with modulo cycling #12
Problem
Multi-GPU inference crashes when `sample_steps > num_gpus_dit`: `--sample_steps 8` with `--num_gpus_dit 4` causes `KeyError: 'cond_shape'` or `KeyError: '5'`.

Root cause: the denoising loop maps step `i` directly to rank `i` (`if i != dist.get_rank(): continue`), so steps 4+ have no GPU to run them.
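For concreteness, a minimal sketch of the failing mapping; `run_dit_step`, `timesteps`, and `latents` are illustrative names, not the repo's exact code:

```python
import torch.distributed as dist

def denoise_old(run_dit_step, timesteps, latents):
    """Old mapping: step i runs only on rank i."""
    for i, t in enumerate(timesteps):      # e.g. 8 denoising steps
        if i != dist.get_rank():           # with 4 DiT GPUs only ranks 0..3
            continue                       # exist, so steps 4..7 match no
                                           # rank and are silently skipped
        latents = run_dit_step(latents, t) # downstream lookups for the
                                           # missing steps raise KeyError
    return latents
```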
Solution

Implement modulo cycling so steps wrap around available GPUs (see the sketch after the change list):
Changes:
- `if i % num_gpus_dit != my_rank` instead of `if i != dist.get_rank()`
- `sample_scheduler._step_index = i` for correct timestep handling
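A minimal sketch of how the two changes fit into the loop, assuming a diffusers-style scheduler (`.step(...).prev_sample`); `dit`, `timesteps`, and `latents` are placeholders, not the repo's exact identifiers:

```python
import torch.distributed as dist

def denoise(dit, sample_scheduler, timesteps, latents, num_gpus_dit=4):
    """New mapping: steps cycle over DiT ranks via modulo."""
    my_rank = dist.get_rank()
    for i, t in enumerate(timesteps):
        if i % num_gpus_dit != my_rank:    # step 4 -> rank 0, step 5 -> rank 1, ...
            continue
        sample_scheduler._step_index = i   # keep the scheduler's internal
                                           # step pointer aligned with i
        noise_pred = dit(latents, t)       # this rank's denoising step
        latents = sample_scheduler.step(noise_pred, t, latents).prev_sample
    return latents
```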
Why KV cache sync is needed

With modulo cycling, each GPU maintains its own KV cache. The denoising loop processes the same temporal block across all steps, with each step overwriting the same cache positions.
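A toy illustration of the resulting staleness with 4 DiT ranks and 8 steps, matching the setup above:

```python
num_gpus_dit, sample_steps = 4, 8

for i in range(sample_steps):
    rank = i % num_gpus_dit                # rank that runs step i
    prev_write = i - num_gpus_dit          # that rank's previous turn
    last = "init" if prev_write < 0 else f"step {prev_write}"
    print(f"step {i}: rank {rank}, local cache last written at {last}")
# Without sync, rank 0's cache at step 4 still holds step 0's KV values,
# four steps behind the globally previous step 3 (which ran on rank 3).
```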
Without sync, GPU 0 at step 4 reads stale KV values from step 0 instead of step 3's values. This breaks temporal attention alignment, causing lipsync drift.
The fix broadcasts the updated KV cache after each step using `dist.send`/`dist.recv`.
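A sketch of what that per-step hand-off could look like for a single cache tensor; the next-rank send/recv pattern and the `kv_cache` tensor layout are assumptions (the repo's cache is likely a set of per-layer tensors):

```python
import torch
import torch.distributed as dist

def sync_kv_cache(kv_cache: torch.Tensor, step: int, num_gpus_dit: int) -> None:
    """After `step` finishes, pass its fresh KV values to the next step's rank."""
    src = step % num_gpus_dit              # rank that just ran this step
    dst = (step + 1) % num_gpus_dit        # rank that runs the next step
    if src == dst:                         # single DiT rank: cache is already local
        return
    rank = dist.get_rank()
    if rank == src:
        dist.send(kv_cache, dst=dst)       # ship the fresh cache forward
    elif rank == dst:
        dist.recv(kv_cache, src=src)       # overwrite the stale local copy
```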
Testing

Tested on 5 GPUs (4 DiT + 1 VAE) with:
- `--sample_steps 8`
- `--sample_steps 16`

Video quality improved with more denoising steps, as expected.