
Commit

Support official scaled fp8 model
kijai committed Dec 18, 2024
1 parent 46f569f commit fb34e91
Showing 3 changed files with 111 additions and 2 deletions.
Binary file added hyvideo/modules/fp8_map.safetensors
Binary file not shown.
106 changes: 106 additions & 0 deletions hyvideo/modules/fp8_optimization.py
@@ -0,0 +1,106 @@
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from comfy.utils import load_torch_file

def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
    _bits = torch.tensor(bits)
    _mantissa_bit = torch.tensor(mantissa_bit)
    _sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
    E = _bits - _sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i+1))
    maxval = mantissa * 2 ** (2**E - 1 - bias)
    return maxval

def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
    """
    Default is E4M3.
    """
    bits = torch.tensor(bits)
    mantissa_bit = torch.tensor(mantissa_bit)
    sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
    E = bits - sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i+1))
    maxval = mantissa * 2 ** (2**E - 1 - bias)
    minval = - maxval
    minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
    input_clamp = torch.min(torch.max(x, minval), maxval)
    log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
    # dequant
    qdq_out = torch.round(input_clamp / log_scales) * log_scales
    return qdq_out, log_scales

def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
    for i in range(len(x.shape) - 1):
        scale = scale.unsqueeze(-1)
    new_x = x / scale
    quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
    return quant_dequant_x, scale, log_scales

def fp8_activation_dequant(qdq_out, scale, dtype):
    qdq_out = qdq_out.type(dtype)
    quant_dequant_x = qdq_out * scale.to(dtype)
    return quant_dequant_x

def fp8_linear_forward(cls, original_dtype, input):
    weight_dtype = cls.weight.dtype
    #####
    if cls.weight.dtype != torch.float8_e4m3fn:
        maxval = get_fp_maxval()
        scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
        linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
        linear_weight = linear_weight.to(torch.float8_e4m3fn)
        weight_dtype = linear_weight.dtype
    else:
        scale = cls.fp8_scale.to(cls.weight.device)
        linear_weight = cls.weight
    #####

    if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
        if True or len(input.shape) == 3:
            cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
            if cls.bias != None:
                output = F.linear(input, cls_dequant, cls.bias)
            else:
                output = F.linear(input, cls_dequant)
            return output
        else:
            return cls.original_forward(input.to(original_dtype))
    else:
        return cls.original_forward(input)

def convert_fp8_linear(module, original_dtype):
    setattr(module, "fp8_matmul_enabled", True)
    script_directory = os.path.dirname(os.path.abspath(__file__))

    # loading fp8 mapping file
    #fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
    fp8_map_path = os.path.join(script_directory, "fp8_map.safetensors")
    if os.path.exists(fp8_map_path):
        #fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
        fp8_map = load_torch_file(fp8_map_path)
    else:
        raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")

    fp8_layers = []
    for key, layer in module.named_modules():
        if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
            fp8_layers.append(key)
            original_forward = layer.forward
            layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
            setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
            setattr(layer, "original_forward", original_forward)
            setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))


7 changes: 5 additions & 2 deletions nodes.py
@@ -239,7 +239,7 @@ def INPUT_TYPES(s):
                 "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
 
                 "base_precision": (["fp32", "bf16"], {"default": "bf16"}),
-                "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6", "torchao_int4", "torchao_int8"], {"default": 'disabled', "tooltip": "optional quantization method"}),
+                "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_scaled', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6", "torchao_int4", "torchao_int8"], {"default": 'disabled', "tooltip": "optional quantization method"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
             },
             "optional": {
@@ -324,7 +324,7 @@ def loadmodel(self, model, base_precision, load_device, quantization,
 
         if not "torchao" in quantization:
             log.info("Using accelerate to load and assign model weights to device...")
-            if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast":
+            if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast" or quantization == "fp8_scaled":
                 dtype = torch.float8_e4m3fn
             else:
                 dtype = base_dtype
@@ -363,6 +363,9 @@ def loadmodel(self, model, base_precision, load_device, quantization,
             if quantization == "fp8_e4m3fn_fast":
                 from .fp8_optimization import convert_fp8_linear
                 convert_fp8_linear(patcher.model.diffusion_model, base_dtype, params_to_keep=params_to_keep)
+            elif quantization == "fp8_scaled":
+                from .hyvideo.modules.fp8_optimization import convert_fp8_linear
+                convert_fp8_linear(patcher.model.diffusion_model, base_dtype)
 
             #compile
             if compile_args is not None:
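For readers wondering what selecting fp8_scaled actually changes: the checkpoint's float8_e4m3fn weights are kept as-is, and convert_fp8_linear attaches a per-layer scale from fp8_map.safetensors and swaps each block Linear's forward for a dequantize-then-matmul path. Below is a standalone sketch of that patching pattern (illustrative only, not the wrapper's actual code; the helper name patch_linear is made up).

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def patch_linear(layer: nn.Linear, scale: torch.Tensor, compute_dtype=torch.bfloat16):
        # Store the scaled-down weight as fp8 plus its per-tensor scale on the module.
        layer.weight = nn.Parameter((layer.weight / scale).to(torch.float8_e4m3fn), requires_grad=False)
        layer.fp8_scale = scale
        layer.original_forward = layer.forward
        # Bind the layer through a default argument so each closure sees its own module.
        layer.forward = lambda x, m=layer: F.linear(
            x, m.weight.to(compute_dtype) * m.fp8_scale.to(compute_dtype), m.bias)

    lin = nn.Linear(64, 64, dtype=torch.bfloat16)
    scale = lin.weight.abs().max() / 448.0  # 448 = E4M3 max value
    patch_linear(lin, scale)
    print(lin(torch.randn(2, 64, dtype=torch.bfloat16)).shape)  # torch.Size([2, 64])

This is the same mechanism convert_fp8_linear applies to every 'double_blocks'/'single_blocks' Linear, except that the real per-layer scales come from fp8_map.safetensors rather than being recomputed.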

3 comments on commit fb34e91

@Ratinod
Where is it? I can't find where to get a scaled model.
Is this it? https://huggingface.co/tencent/HunyuanVideo/blob/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt
Maybe there is a .safetensors version somewhere?

@kijai (Owner, Author) commented on fb34e91, Dec 18, 2024

That's the one; the file is safe and loads with weights_only, so there's no risk in using it.

@maplecasino
Taking your word that the file is safe. It's a pickle file: a .pt (a zip archive containing a 101-kilobyte file named data.pkl), and pickle-encoded data isn't guaranteed to be safe to load the way safetensors is.
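For context on the two loading paths being compared here (a hedged sketch, not code from this repo): torch.load with weights_only=True uses a restricted unpickler that only reconstructs tensors and primitive containers, while safetensors avoids pickle entirely.

    import torch
    from safetensors.torch import load_file

    # Pickle-based .pt checkpoint, loaded through the restricted unpickler:
    state = torch.load("mp_rank_00_model_states_fp8.pt", map_location="cpu", weights_only=True)

    # safetensors container (e.g. the bundled fp8_map.safetensors): no pickle involved at all.
    fp8_map = load_file("fp8_map.safetensors", device="cpu")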
