
Commit

Add alternative cpu offloading
kijai committed Dec 23, 2024
1 parent 2fe649b commit f52a49f
Showing 3 changed files with 208 additions and 3 deletions.
196 changes: 196 additions & 0 deletions hyvideo/modules/models.py
@@ -18,6 +18,58 @@
from .token_refiner import SingleTokenRefiner
from ...enhance_a_video.enhance import get_feta_scores
from ...enhance_a_video.globals import is_enhance_enabled_single, is_enhance_enabled_double, set_num_frames
from .norm_layers import RMSNorm

from contextlib import contextmanager

@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):

old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer

def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)

def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)

return wrapper

if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}

try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
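
The context manager above mirrors accelerate's init_empty_weights: it temporarily patches register_parameter (and, optionally, register_buffer and the torch tensor constructors) so that modules created inside it get meta-device parameters that allocate no memory. A minimal sketch of the intended usage, assuming only the function above and PyTorch >= 2.1 for assign=True; the snippet is illustrative, not part of the commit:

import torch

source = torch.nn.Linear(16, 32)               # pretend these are the trained weights

with init_weights_on_device(device=torch.device("meta")):
    empty = torch.nn.Linear(16, 32)            # created without allocating weight storage

print(empty.weight.device)                     # meta
empty.load_state_dict(source.state_dict(), assign=True)
print(empty.weight.device)                     # cpu, now backed by the source tensors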

class MMDoubleStreamBlock(nn.Module):
"""
@@ -644,6 +696,150 @@ def block_swap(self, double_blocks_to_swap, single_blocks_to_swap, offload_txt_i
else:
block.to(self.offload_device)

def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)

r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r

def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight

def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
return weight, bias

class quantized_layer:
class Linear(torch.nn.Linear):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device

def block_forward_(self, x, i, j, dtype, device):
weight_ = cast_to(
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
dtype=dtype, device=device
)
if self.bias is None or i > 0:
bias_ = None
else:
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
y_ = torch.nn.functional.linear(x_, weight_, bias_)
del x_, weight_, bias_
torch.cuda.empty_cache()
return y_

def block_forward(self, x, **kwargs):
# Block-wise forward only saves about 2 GB of VRAM, so it is not used; forward() below casts the full weight instead.
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
for i in range((self.in_features + self.block_size - 1) // self.block_size):
for j in range((self.out_features + self.block_size - 1) // self.block_size):
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
return y

def forward(self, x, **kwargs):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.linear(x, weight, bias)


class RMSNorm(torch.nn.Module):
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
super().__init__()
self.module = module
self.dtype = dtype
self.device = device

def forward(self, hidden_states, **kwargs):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype)
if self.module.weight is not None:
weight = cast_weight(self.module, hidden_states, dtype=self.dtype, device=self.device)
hidden_states = hidden_states * weight
return hidden_states

class Conv3d(torch.nn.Conv3d):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device

def forward(self, x):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)

class LayerNorm(torch.nn.LayerNorm):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device

def forward(self, x):
if self.weight is not None and self.bias is not None:
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
else:
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
with init_weights_on_device():
new_layer = quantized_layer.Linear(
module.in_features, module.out_features, bias=module.bias is not None,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
elif isinstance(module, torch.nn.Conv3d):
with init_weights_on_device():
new_layer = quantized_layer.Conv3d(
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
new_layer = quantized_layer.RMSNorm(
module,
dtype=dtype, device=device
)
setattr(model, name, new_layer)
elif isinstance(module, torch.nn.LayerNorm):
with init_weights_on_device():
new_layer = quantized_layer.LayerNorm(
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
else:
replace_layer(module, dtype=dtype, device=device)

replace_layer(self, dtype=dtype, device=device)
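
The replacement layers above implement on-demand weight streaming: parameters stay wherever they were loaded (typically CPU RAM), and cast_bias_weight copies them to the input's device and dtype only for the duration of each forward call, which is what lets the bulk of the transformer sit off the GPU. A self-contained sketch of that pattern under simplified assumptions (OffloadLinear and stream_to are illustrative names, not repo code):

import torch

def stream_to(weight, dtype=None, device=None):
    # Copy the tensor to the requested dtype/device only when it does not already match.
    if (device is None or weight.device == device) and (dtype is None or weight.dtype == dtype):
        return weight
    out = torch.empty_like(weight, dtype=dtype, device=device)
    out.copy_(weight)
    return out

class OffloadLinear(torch.nn.Linear):
    def forward(self, x):
        # Weights live on the CPU; a temporary copy follows the input for this call only.
        weight = stream_to(self.weight, dtype=x.dtype, device=x.device)
        bias = stream_to(self.bias, dtype=x.dtype, device=x.device) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)

device = "cuda" if torch.cuda.is_available() else "cpu"
layer = OffloadLinear(64, 128)                               # parameters stay on the CPU
x = torch.randn(2, 64, device=device, dtype=torch.bfloat16)
y = layer(x)                                                 # weights are copied to `device` only here
print(y.shape, y.dtype, layer.weight.device)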

def enable_deterministic(self):
for block in self.double_blocks:
block.enable_deterministic()
3 changes: 2 additions & 1 deletion hyvideo/modules/token_refiner.py
@@ -217,6 +217,7 @@ def forward(
t: torch.LongTensor,
mask: Optional[torch.LongTensor] = None,
):
in_dtype = x.dtype
timestep_aware_representations = self.t_embedder(t)

if mask is None:
@@ -226,7 +227,7 @@
context_aware_representations = (x * mask_float).sum(
dim=1
) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
context_aware_representations = self.c_embedder(context_aware_representations.to(in_dtype))
c = timestep_aware_representations + context_aware_representations

x = self.input_embedder(x)
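
The change above records the incoming dtype and casts the masked mean back to it before c_embedder: multiplying (typically bf16) hidden states by a float32 mask promotes the pooled result to float32, which would otherwise reach the embedder in the wrong dtype. A small standalone illustration of the promotion, with made-up shapes and the assumption that mask_float is float32 (as a .float() mask would be):

import torch

x = torch.randn(1, 4, 8, dtype=torch.bfloat16)   # token features
mask_float = torch.ones(1, 4, 1)                 # float32 mask
pooled = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)

print(pooled.dtype)              # torch.float32: bf16 * fp32 promotes to fp32
print(pooled.to(x.dtype).dtype)  # torch.bfloat16 again, matching in_dtype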
12 changes: 10 additions & 2 deletions nodes.py
@@ -224,6 +224,7 @@ def INPUT_TYPES(s):
"compile_args": ("COMPILEARGS", ),
"block_swap_args": ("BLOCKSWAPARGS", ),
"lora": ("HYVIDLORA", {"default": None}),
"auto_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Enable auto offloading for reduced VRAM usage, implementation from DiffSynth-Studio, slightly different from block swapping and uses even less VRAM, but can be slower as you can't define how much VRAM to use"}),
}
}

@@ -233,7 +234,7 @@ def INPUT_TYPES(s):
CATEGORY = "HunyuanVideoWrapper"

def loadmodel(self, model, base_precision, load_device, quantization,
compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None):
compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, auto_cpu_offload=False):
transformer = None
#mm.unload_all_models()
mm.soft_empty_cache()
@@ -339,6 +340,9 @@ def loadmodel(self, model, base_precision, load_device, quantization,
from .hyvideo.modules.fp8_optimization import convert_fp8_linear
convert_fp8_linear(patcher.model.diffusion_model, base_dtype)

if auto_cpu_offload:
transformer.enable_auto_offload(dtype=dtype, device=device)

#compile
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
@@ -440,6 +444,7 @@ def loadmodel(self, model, base_precision, load_device, quantization,
patcher.model["manual_offloading"] = manual_offloading
patcher.model["quantization"] = "disabled"
patcher.model["block_swap_args"] = block_swap_args
patcher.model["auto_offload"] = auto_cpu_offload

return (patcher,)

@@ -1117,7 +1122,10 @@ def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scal
offload_txt_in = model["block_swap_args"]["offload_txt_in"],
offload_img_in = model["block_swap_args"]["offload_img_in"],
)

elif model["auto_cpu_offload"]:
for name, param in transformer.named_parameters():
if "single" not in name and "double" not in name:
param.data = param.data.to(device)
elif model["manual_offloading"]:
transformer.to(device)

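
When auto offload is active, the branch above moves only the parameters outside the double/single stream blocks onto the GPU; the blocks themselves stay in CPU RAM and are streamed per layer by the modules replaced in enable_auto_offload. A toy illustration of the same name-based filter (ToyTransformer is made up; the loop mirrors the one above):

import torch

class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.img_in = torch.nn.Linear(8, 8)
        self.double_blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8)])
        self.single_blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8)])

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyTransformer()
for name, param in model.named_parameters():
    if "single" not in name and "double" not in name:
        param.data = param.data.to(device)

for name, param in model.named_parameters():
    print(name, param.device)    # img_in.* on `device`, *_blocks.* still on the CPU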
