
Conversation

@Ohm-Rishabh (Contributor) commented Dec 27, 2025

This feature implements layer-wise offloading: layers are asynchronously offloaded to the CPU while the next layer is prefetched to the GPU over the CPU-GPU communication channel.

Without layer-wise offloading, swapping transformers during denoising (high-noise to low-noise) for Wan2.2 adds an overhead of ~25s.
End-to-end run time: 457.98s

With this feature the layers are streamed, so the overhead at the swap boundary is negligible at ~2s.
End-to-end run time: 424.46s

All experiments were run for Wan2.2 with height=720, width=1280, num_frames=81, sp=4, FSDP disabled.
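
For readers who want the mechanics behind "streaming the layers", below is a minimal, self-contained sketch of the pattern (pinned CPU master weights, a dedicated copy stream, prefetch of layer i+1 overlapped with compute of layer i). It is not the PR's LayerwiseOffloadManager; every name in it is illustrative, and it assumes a CUDA device is available.

import torch
import torch.nn as nn

device = torch.device("cuda")
copy_stream = torch.cuda.Stream()

layers = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(8)]).eval()

# Master copies stay in pinned CPU memory so host-to-device copies can be asynchronous.
cpu_weights = {p: p.data.cpu().pin_memory() for p in layers.parameters()}
for p in layers.parameters():
    p.data = torch.empty(0, device=device)  # placeholder until prefetched

def prefetch(layer: nn.Module) -> torch.cuda.Event:
    """Copy one layer's weights to the GPU on the side stream and return a completion event."""
    evt = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        for p in layer.parameters():
            p.data = cpu_weights[p].to(device, non_blocking=True)
        evt.record()
    return evt

def release(layer: nn.Module) -> None:
    """Drop the GPU copy; the pinned CPU master copy remains authoritative."""
    for p in layer.parameters():
        # Mark the tensor as used on the compute stream before freeing it,
        # so the caching allocator does not recycle the memory too early.
        p.data.record_stream(torch.cuda.current_stream())
        p.data = torch.empty(0, device=device)

x = torch.randn(4, 1024, device=device)
events = {0: prefetch(layers[0])}
with torch.no_grad():
    for i, layer in enumerate(layers):
        if i + 1 < len(layers):                 # overlap the next copy with this compute
            events[i + 1] = prefetch(layers[i + 1])
        torch.cuda.current_stream().wait_event(events[i])  # this layer's weights must have landed
        x = layer(x)
        release(layer)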

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @Ohm-Rishabh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance optimization for DIT models by implementing layer-wise offloading. This mechanism intelligently moves model layers between the CPU and GPU, prefetching upcoming layers to minimize latency during model execution, particularly when swapping transformers. The primary goal is to drastically reduce the overhead associated with such operations, leading to faster end-to-end runtime by efficiently managing memory and data transfer.

Highlights

  • Performance Optimization: Introduced layer-wise offloading for DIT models, significantly reducing overhead during transformer swaps from approximately 25 seconds to just 2 seconds, leading to faster end-to-end runtime.
  • Asynchronous Offloading: Implemented an asynchronous mechanism to offload model layers to the CPU and prefetch subsequent layers to the GPU, leveraging the CPU-GPU communication channel for efficiency.
  • New LayerwiseOffloadManager: Added a dedicated LayerwiseOffloadManager class to handle the complex logic of layer offloading, prefetching, and releasing, ensuring optimal memory management and performance (see the sketch after this list).
  • Configuration and Integration: A new CLI argument --dit-layerwise-offload has been added to enable this feature. The offloading logic is integrated into the wanvideo.py forward pass and the denoising.py stage, with automatic disabling of use_fsdp_inference and dit_cpu_offload when layer-wise offloading is active.
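
For context on the layer_scope(...) calls quoted in the review comments below, here is a hypothetical skeleton of how such a context manager could be structured around a prefetch/wait/release cycle. It is a guess at the shape of the interface, not the PR's actual LayerwiseOffloadManager; the copy and placeholder bodies are elided.

from contextlib import contextmanager
import torch

class LayerwiseOffloadSketch:
    """Hypothetical skeleton; the real manager in this PR may differ."""

    def __init__(self) -> None:
        self.enabled = True
        self._copy_stream = torch.cuda.Stream()
        self._copy_events: dict = {}

    def prefetch_layer(self, idx: int, non_blocking: bool = True) -> None:
        evt = torch.cuda.Event()
        with torch.cuda.stream(self._copy_stream):
            # ... queue host-to-device copies for layer `idx` here ...
            evt.record()
        self._copy_events[idx] = evt

    def release_layer(self, idx: int) -> None:
        # ... swap layer `idx` GPU parameters back to placeholders here ...
        self._copy_events.pop(idx, None)

    @contextmanager
    def layer_scope(self, prefetch_layer_idx, release_layer_idx, non_blocking=True):
        # Kick off the copy for the next layer so it overlaps this layer's compute.
        if prefetch_layer_idx is not None:
            self.prefetch_layer(prefetch_layer_idx, non_blocking=non_blocking)
        # Block the compute stream until this layer's own weights have arrived.
        evt = self._copy_events.get(release_layer_idx)
        if evt is not None:
            torch.cuda.current_stream().wait_event(evt)
        try:
            yield
        finally:
            self.release_layer(release_layer_idx)

In the wanvideo.py forward pass quoted in the review below, each transformer block is then wrapped in layer_scope(prefetch_layer_idx=i + 1 if i + 1 < len(self.blocks) else None, release_layer_idx=i).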


@gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request introduces a layer-wise offloading feature, which significantly reduces overhead when swapping transformers, leading to a noticeable performance improvement. The implementation is centered around a new LayerwiseOffloadManager that handles asynchronous prefetching and releasing of model layers. The changes are well-structured. My review focuses on improving code quality and maintainability by reducing code duplication and encouraging more specific exception handling for robustness.

Comment on lines 736 to 749
offload_mgr = getattr(self, "_layerwise_offload_manager", None)
if offload_mgr is not None and getattr(offload_mgr, "enabled", False):
    for i, block in enumerate(self.blocks):
        with offload_mgr.layer_scope(
            prefetch_layer_idx=i + 1 if i + 1 < len(self.blocks) else None,
            release_layer_idx=i,
            non_blocking=True,
        ):
            hidden_states = block(hidden_states, encoder_hidden_states,
                                  timestep_proj, freqs_cis, attention_mask)
else:
    for block in self.blocks:
        hidden_states = block(hidden_states, encoder_hidden_states,
                              timestep_proj, freqs_cis, attention_mask)
Contributor:

medium

This if/else block contains duplicated code for iterating through transformer blocks. You can refactor this to be more concise and maintainable by using contextlib.nullcontext for the case where offloading is disabled. This avoids repeating the loop body.

First, add from contextlib import nullcontext at the top of the file.

Then, you can replace this entire block with:

offload_mgr = getattr(self, "_layerwise_offload_manager", None)
use_offload = offload_mgr is not None and getattr(offload_mgr, "enabled", False)

for i, block in enumerate(self.blocks):
    scope = offload_mgr.layer_scope(
        prefetch_layer_idx=i + 1 if i + 1 < len(self.blocks) else None,
        release_layer_idx=i,
        non_blocking=True,
    ) if use_offload else nullcontext()
    
    with scope:
        hidden_states = block(hidden_states, encoder_hidden_states,
                              timestep_proj, freqs_cis, attention_mask)

Collaborator:

please use nullcontext to avoid code duplication

Comment on lines +57 to +67
try:
    return int(m.group(2))
except Exception:
    return None
Contributor:

medium

Using a broad except Exception: can hide other potential issues. It's better to catch a more specific exception. In this case, int() will raise a ValueError if the conversion fails.

Suggested change
try:
    return int(m.group(2))
except Exception:
    return None
try:
    return int(m.group(2))
except ValueError:
    return None


def _record_meta(self, name: str, t: torch.Tensor) -> None:
    if name not in self._meta:
        self._meta[name] = (int(t.ndim), t.dtype)
Contributor:

medium

The int() cast around t.ndim is redundant, as torch.Tensor.ndim already returns an integer. Removing it makes the code slightly cleaner.

Suggested change
self._meta[name] = (int(t.ndim), t.dtype)
self._meta[name] = (t.ndim, t.dtype)

Comment on lines +210 to +224
for layer_idx in list(self._gpu_layers.keys()):
    param_names = self._gpu_layers.pop(layer_idx, None)
    if not param_names:
        continue
    for name in param_names:
        target = self._get_target(name)
        self._record_meta(name, target)
        target.data = self._make_placeholder(name)
Contributor:

medium

The logic inside this loop is a duplication of the release_layer method. To improve maintainability and reduce code duplication, you can refactor this to call self.release_layer() inside the loop.

Suggested change
for layer_idx in list(self._gpu_layers.keys()):
    param_names = self._gpu_layers.pop(layer_idx, None)
    if not param_names:
        continue
    for name in param_names:
        target = self._get_target(name)
        self._record_meta(name, target)
        target.data = self._make_placeholder(name)
for layer_idx in list(self._gpu_layers.keys()):
    self.release_layer(layer_idx)

Comment on lines 512 to 515
try:
    num_layers = len(getattr(model, "blocks"))
except Exception:
    num_layers = None
Contributor:

medium

Using a broad except Exception: can mask unexpected errors. It's better to catch more specific exceptions. In this case, len() on a non-sequence object would raise a TypeError.

            except TypeError:
                num_layers = None

# This ensures non-managed parameters (embeddings, final norms) are on GPU
model = model.to(get_local_torch_device())

from fastvideo.models.layerwise_offload import LayerwiseOffloadManager
Contributor:

medium

Local imports, like this one for LayerwiseOffloadManager, can lead to circular dependency issues and make code harder to read and maintain. It's a best practice to move all imports to the top of the file.

Collaborator:

don't use local imports

Comment on lines 266 to 273
    transformer_device = next(
        current_model.parameters()).device.type
    if transformer_device == 'cpu':
        current_model.to(get_local_torch_device())
        # Sync to ensure onloading completes
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    current_guidance_scale = batch.guidance_scale
else:
    # low-noise stage in wan2.2
    if fastvideo_args.dit_cpu_offload and next(
            self.transformer.parameters(
            )).device.type == 'cuda':
    if (fastvideo_args.dit_cpu_offload
Contributor:

medium

This block of code to ensure the model is loaded onto the GPU is duplicated for the else case below (lines 286-296). To improve code maintainability and reduce duplication, consider extracting this logic into a helper method within the class.
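
One possible shape for that helper, as a standalone sketch (the function name is hypothetical; get_local_torch_device would supply the device, as in the quoted snippet):

import torch
from torch import nn

def ensure_on_gpu(model: nn.Module, device: torch.device) -> None:
    """Move the model onto `device` if its parameters currently live on the CPU."""
    if next(model.parameters()).device.type == "cpu":
        model.to(device)

Both branches could then call it on the transformer they are about to run, e.g. ensure_on_gpu(current_model, get_local_torch_device()).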

Comment on lines +490 to 492
if self.transformer_2 is not None:
    mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager",
                   None)
    if mgr2 is not None and getattr(mgr2, "enabled", False):
        mgr2.release_all()

# Save STA mask search results if needed
if st_attn_available and self.attn_backend == SlidingTileAttentionBackend and fastvideo_args.STA_mode == STA_Mode.STA_SEARCHING:
    self.save_sta_search_results(batch)
Contributor:

medium

This logic for releasing offloaded layers is duplicated for self.transformer and self.transformer_2. To make the code more concise and maintainable, you could extract this into a small helper function.
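
A sketch of the kind of helper the reviewer has in mind, applied once to self.transformer and once to self.transformer_2 (the function name is hypothetical; _layerwise_offload_manager, enabled, and release_all come from the quoted code):

from typing import Optional
from torch import nn

def release_offloaded_layers(module: Optional[nn.Module]) -> None:
    """Release all GPU-resident layers if the module has an enabled offload manager."""
    if module is None:
        return
    mgr = getattr(module, "_layerwise_offload_manager", None)
    if mgr is not None and getattr(mgr, "enabled", False):
        mgr.release_all()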

@SolitaryThinker added the go (Trigger Buildkite CI) label on Dec 27, 2025
Comment on lines 291 to 292
if torch.cuda.is_available():
    torch.cuda.synchronize()
Collaborator:

not needed

Comment on lines 271 to 272
if torch.cuda.is_available():
    torch.cuda.synchronize()
Collaborator:

not needed

@@ -0,0 +1,217 @@
import re
Collaborator:

At the top, could you add a citation to SGLang's implementation?

Also make sure new files have the Apache 2.0 license header.
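
For example, a header along these lines at the top of the new layerwise_offload.py (SPDX form and exact wording are assumptions; adjust to the project's existing header style):

# SPDX-License-Identifier: Apache-2.0
# Layer-wise offloading adapted from SGLang's implementation:
# https://github.com/sgl-project/sglang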



model = model.eval()

if fastvideo_args.dit_layerwise_offload and hasattr(model, "blocks"):
Collaborator:

Currently only Wan supports layer-wise offload, right? Could you emit a warning and turn off layer-wise offloading for unsupported models?
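
A sketch of the suggested guard at the load site (logger is assumed to be the module's existing logger; fastvideo_args.dit_layerwise_offload and the blocks check come from the quoted snippet):

if fastvideo_args.dit_layerwise_offload and not hasattr(model, "blocks"):
    logger.warning(
        "Layer-wise offload is only supported for models exposing a `blocks` "
        "module list (currently Wan); disabling --dit-layerwise-offload.")
    fastvideo_args.dit_layerwise_offload = False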


@SolitaryThinker merged commit d83f45a into hao-ai-lab:main on Jan 4, 2026
1 check passed