Layer offloading #966
Conversation
Summary of Changes

Hello @Ohm-Rishabh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request introduces a significant performance optimization for DiT models by implementing layer-wise offloading. The mechanism moves model layers between the CPU and GPU, prefetching upcoming layers to minimize latency during model execution, particularly when swapping transformers. The primary goal is to drastically reduce the overhead of such swaps, leading to faster end-to-end runtime through more efficient memory management and data transfer.
Code Review
This pull request introduces a layer-wise offloading feature, which significantly reduces overhead when swapping transformers, leading to a noticeable performance improvement. The implementation is centered around a new LayerwiseOffloadManager that handles asynchronous prefetching and releasing of model layers. The changes are well-structured. My review focuses on improving code quality and maintainability by reducing code duplication and encouraging more specific exception handling for robustness.
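For reviewers who want the gist without reading the whole diff, here is a minimal toy sketch of the prefetch/release pattern described above, using a side CUDA stream to copy block i+1 up while block i runs. The class and method names are illustrative only; the PR's actual manager works at the parameter level (placeholders, recorded metadata) rather than moving whole modules with .to().

import contextlib
import torch
import torch.nn as nn


class ToyLayerwiseOffloader:
    """Toy illustration only: layers live on CPU, and while layer i runs on
    the default stream, layer i+1 is copied to the GPU on a side stream."""

    def __init__(self, blocks: nn.ModuleList, device: str = "cuda") -> None:
        self.blocks = blocks
        self.device = device
        self.enabled = torch.cuda.is_available()
        self.copy_stream = torch.cuda.Stream() if self.enabled else None

    def prefetch_layer(self, idx: int, non_blocking: bool = True) -> None:
        # Issue the host-to-device copy on the side stream so it overlaps
        # with compute running on the default stream.
        with torch.cuda.stream(self.copy_stream):
            self.blocks[idx].to(self.device, non_blocking=non_blocking)

    def release_layer(self, idx: int) -> None:
        # Drop the finished layer back to CPU to keep GPU memory bounded.
        self.blocks[idx].to("cpu")

    @contextlib.contextmanager
    def layer_scope(self, prefetch_layer_idx, release_layer_idx,
                    non_blocking=True):
        if self.enabled and prefetch_layer_idx is not None:
            self.prefetch_layer(prefetch_layer_idx, non_blocking)
        yield
        if self.enabled:
            # Ensure the prefetch copy has landed before the next layer
            # runs, then release the layer that just executed.
            torch.cuda.current_stream().wait_stream(self.copy_stream)
            self.release_layer(release_layer_idx)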
offload_mgr = getattr(self, "_layerwise_offload_manager", None)
if offload_mgr is not None and getattr(offload_mgr, "enabled", False):
    for i, block in enumerate(self.blocks):
        with offload_mgr.layer_scope(
            prefetch_layer_idx=i + 1 if i + 1 < len(self.blocks) else None,
            release_layer_idx=i,
            non_blocking=True,
        ):
            hidden_states = block(hidden_states, encoder_hidden_states,
                                  timestep_proj, freqs_cis, attention_mask)
else:
    for block in self.blocks:
        hidden_states = block(hidden_states, encoder_hidden_states,
                              timestep_proj, freqs_cis, attention_mask)
This if/else block contains duplicated code for iterating through transformer blocks. You can refactor this to be more concise and maintainable by using contextlib.nullcontext for the case where offloading is disabled. This avoids repeating the loop body.
First, add from contextlib import nullcontext at the top of the file.
Then, you can replace this entire block with:
offload_mgr = getattr(self, "_layerwise_offload_manager", None)
use_offload = offload_mgr is not None and getattr(offload_mgr, "enabled", False)
for i, block in enumerate(self.blocks):
    scope = offload_mgr.layer_scope(
        prefetch_layer_idx=i + 1 if i + 1 < len(self.blocks) else None,
        release_layer_idx=i,
        non_blocking=True,
    ) if use_offload else nullcontext()
    with scope:
        hidden_states = block(hidden_states, encoder_hidden_states,
                              timestep_proj, freqs_cis, attention_mask)
Please use nullcontext to avoid code duplication.
try:
    return int(m.group(2))
except Exception:
    return None
Using a broad except Exception: can hide other potential issues. It's better to catch a more specific exception. In this case, int() will raise a ValueError if the conversion fails.
try:
    return int(m.group(2))
except ValueError:
    return None
def _record_meta(self, name: str, t: torch.Tensor) -> None:
    if name not in self._meta:
        self._meta[name] = (int(t.ndim), t.dtype)
for layer_idx in list(self._gpu_layers.keys()):
    param_names = self._gpu_layers.pop(layer_idx, None)
    if not param_names:
        continue
    for name in param_names:
        target = self._get_target(name)
        self._record_meta(name, target)
        target.data = self._make_placeholder(name)
The logic inside this loop is a duplication of the release_layer method. To improve maintainability and reduce code duplication, you can refactor this to call self.release_layer() inside the loop.
for layer_idx in list(self._gpu_layers.keys()):
    self.release_layer(layer_idx)
try:
    num_layers = len(getattr(model, "blocks"))
except Exception:
    num_layers = None
# This ensures non-managed parameters (embeddings, final norms) are on GPU
model = model.to(get_local_torch_device())

from fastvideo.models.layerwise_offload import LayerwiseOffloadManager
Don't use local imports.
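That is, hoist it to module scope with the file's other top-level imports (assuming no circular-import constraint is what forced it to be local):

from fastvideo.models.layerwise_offload import LayerwiseOffloadManager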
transformer_device = next(
    current_model.parameters()).device.type
if transformer_device == 'cpu':
    current_model.to(get_local_torch_device())
    # Sync to ensure onloading completes
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    current_guidance_scale = batch.guidance_scale
else:
    # low-noise stage in wan2.2
    if fastvideo_args.dit_cpu_offload and next(
            self.transformer.parameters()).device.type == 'cuda':
    if (fastvideo_args.dit_cpu_offload
if self.transformer_2 is not None:
    mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager",
                   None)
    if mgr2 is not None and getattr(mgr2, "enabled", False):
        mgr2.release_all()

# Save STA mask search results if needed
if st_attn_available and self.attn_backend == SlidingTileAttentionBackend and fastvideo_args.STA_mode == STA_Mode.STA_SEARCHING:
    self.save_sta_search_results(batch)
if torch.cuda.is_available():
    torch.cuda.synchronize()
Not needed.
if torch.cuda.is_available():
    torch.cuda.synchronize()
Not needed.
@@ -0,0 +1,217 @@
import re
At the top, could you add a citation to SGLang's implementation? Also, make sure new files have the Apache 2.0 license header.
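For example, a header along these lines at the top of the new file (the SPDX form and the SGLang URL are suggestions; exact wording should follow the repo's existing convention):

# SPDX-License-Identifier: Apache-2.0
#
# Layer-wise CPU<->GPU offloading manager.
# Prefetch/release design adapted from SGLang's layer-wise offloading:
# https://github.com/sgl-project/sglang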
model = model.eval()

if fastvideo_args.dit_layerwise_offload and hasattr(model, "blocks"):
Currently only Wan supports layer-wise offload, right? Could you emit a warning and turn off layer-wise offload for unsupported models?
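One possible shape for that guard (illustrative sketch; the logger name, message text, and the hasattr check are assumptions, not code from this PR):

if fastvideo_args.dit_layerwise_offload and not hasattr(model, "blocks"):
    # Hypothetical fallback: only models exposing a `blocks` ModuleList
    # (currently the Wan DiT) are supported, so warn and disable.
    logger.warning(
        "Layer-wise offload is not supported for %s; disabling it.",
        type(model).__name__)
    fastvideo_args.dit_layerwise_offload = False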
Force-pushed from cd4ce11 to e979dcb.
This feature implements layer-wise offloading: layers are asynchronously offloaded and the next layer is prefetched over the CPU-GPU communication channel.

Without layer-wise offloading, swapping transformers during denoising (high-noise to low-noise) for Wan2.2 incurs an overhead of ~25s.
End-to-end run time: 457.98s

With this feature, the layers are streamed and the overhead at the swap boundary is negligible (~2s).
End-to-end run time: 424.46s

All experiments were run on Wan2.2 with height=720, width=1280, num_frames=81, sp=4, and FSDP disabled.
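For context, a hypothetical usage sketch of how the flag exercised in this diff might be enabled from the Python API. VideoGenerator.from_pretrained, generate_video, the keyword forwarding into FastVideoArgs, and the model id are all assumptions here, not something this PR defines:

from fastvideo import VideoGenerator  # assumed public entrypoint

# Assumption: from_pretrained forwards this keyword into FastVideoArgs;
# `dit_layerwise_offload` is the attribute referenced in the diff above.
generator = VideoGenerator.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers",  # illustrative model id
    dit_layerwise_offload=True,
)
video = generator.generate_video(
    "a red panda playing in the snow",       # illustrative prompt
    height=720, width=1280, num_frames=81,   # settings from the benchmark above
)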