Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions diffsynth_engine/utils/offload.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from typing import Dict

import platform

def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda"):
module = module.to("cpu")
Expand All @@ -26,13 +26,14 @@ def _forward_pre_hook(module: nn.Module, input_):
for name, buffer in module.named_buffers(recurse=recurse):
buffer.data = buffer.data.to(device=device)
return tuple(x.to(device=device) if isinstance(x, torch.Tensor) else x for x in input_)

for name, param in module.named_parameters(recurse=recurse):
param.data = param.data.pin_memory()
for name, param in module.named_parameters(recurse=recurse):
if platform.system() == 'Linux':
param.data = param.data.pin_memory()
offload_param_dict[name] = param.data
param.data = param.data.to(device=device)
for name, buffer in module.named_buffers(recurse=recurse):
buffer.data = buffer.data.pin_memory()
if platform.system() == 'Linux':
buffer.data = buffer.data.pin_memory()
offload_param_dict[name] = buffer.data
buffer.data = buffer.data.to(device=device)
setattr(module, "_offload_param_dict", offload_param_dict)
Comment on lines +29 to 39
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling platform.system() inside loops is inefficient as the result will not change during execution. It's better to call it once before the loops and store the result in a variable. This will improve performance by avoiding redundant system calls.

Additionally, there are some minor style issues: there are trailing whitespaces and two spaces around == instead of one (platform.system() == 'Linux').

Suggested change
for name, param in module.named_parameters(recurse=recurse):
if platform.system() == 'Linux':
param.data = param.data.pin_memory()
offload_param_dict[name] = param.data
param.data = param.data.to(device=device)
for name, buffer in module.named_buffers(recurse=recurse):
buffer.data = buffer.data.pin_memory()
if platform.system() == 'Linux':
buffer.data = buffer.data.pin_memory()
offload_param_dict[name] = buffer.data
buffer.data = buffer.data.to(device=device)
setattr(module, "_offload_param_dict", offload_param_dict)
is_linux = platform.system() == "Linux"
for name, param in module.named_parameters(recurse=recurse):
if is_linux:
param.data = param.data.pin_memory()
offload_param_dict[name] = param.data
param.data = param.data.to(device=device)
for name, buffer in module.named_buffers(recurse=recurse):
if is_linux:
buffer.data = buffer.data.pin_memory()
offload_param_dict[name] = buffer.data
buffer.data = buffer.data.to(device=device)
setattr(module, "_offload_param_dict", offload_param_dict)

Expand All @@ -58,10 +59,12 @@ def offload_model_to_dict(module: nn.Module) -> Dict[str, torch.Tensor]:
module = module.to("cpu")
offload_param_dict = {}
for name, param in module.named_parameters(recurse=True):
param.data = param.data.pin_memory()
if platform.system() == 'Linux':
param.data = param.data.pin_memory()
offload_param_dict[name] = param.data
for name, buffer in module.named_buffers(recurse=True):
buffer.data = buffer.data.pin_memory()
if platform.system() == 'Linux':
buffer.data = buffer.data.pin_memory()
offload_param_dict[name] = buffer.data
Comment on lines 61 to 68
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the previous comment, platform.system() is being called inside loops, which is inefficient. It's better to perform this check once before the loops.

Also, there are some minor style issues: there are trailing whitespaces and two spaces around == instead of one (platform.system() == 'Linux').

Suggested change
for name, param in module.named_parameters(recurse=True):
param.data = param.data.pin_memory()
if platform.system() == 'Linux':
param.data = param.data.pin_memory()
offload_param_dict[name] = param.data
for name, buffer in module.named_buffers(recurse=True):
buffer.data = buffer.data.pin_memory()
if platform.system() == 'Linux':
buffer.data = buffer.data.pin_memory()
offload_param_dict[name] = buffer.data
is_linux = platform.system() == "Linux"
for name, param in module.named_parameters(recurse=True):
if is_linux:
param.data = param.data.pin_memory()
offload_param_dict[name] = param.data
for name, buffer in module.named_buffers(recurse=True):
if is_linux:
buffer.data = buffer.data.pin_memory()
offload_param_dict[name] = buffer.data

return offload_param_dict

Expand Down