
Commit bdb10a5: Fix loras not working on mixed fp8. (#10899)
Parent: 0e24dbb

4 files changed: +37 -9 lines

comfy/model_patcher.py
comfy/ops.py
comfy/quant_ops.py
comfy/weight_adapter/lora.py

comfy/model_patcher.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -132,7 +132,7 @@ def __init__(self, key, patches, convert_func=None, set_func=None):
     def __call__(self, weight):
         intermediate_dtype = weight.dtype
         if self.convert_func is not None:
-            weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+            weight = self.convert_func(weight, inplace=False)
 
         if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
```
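Before this change, `__call__` forced a float32 copy and asked `convert_func` to work in place; now the raw (possibly quantized) weight is handed over and the convert function owns the copy, upcast, or dequantize step. A minimal sketch of the resulting contract; the names here are illustrative, not ComfyUI API:

```python
# Sketch of the contract the new call relies on: convert_func receives the
# raw (possibly quantized) weight and returns a plain tensor safe to patch.
import torch

def apply_patches(weight, convert_func, patches):
    intermediate_dtype = weight.dtype
    if convert_func is not None:
        # No upfront weight.to(torch.float32, copy=True): the convert
        # function decides whether to copy, upcast, or dequantize.
        weight = convert_func(weight, inplace=False)
    if intermediate_dtype not in (torch.float32, torch.float16, torch.bfloat16):
        intermediate_dtype = torch.float32  # fp8 etc. cannot do the patch math
    for patch in patches:
        weight = patch(weight.to(intermediate_dtype))
    return weight
```

Centralizing the conversion in `convert_func` is what lets quantized layouts dequantize themselves (see `comfy/ops.py` below) instead of being blindly cast through float32.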

comfy/ops.py

Lines changed: 21 additions & 1 deletion
```diff
@@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if weight_has_function or weight.dtype != dtype:
         with wf_context:
             weight = weight.to(dtype=dtype)
+            if isinstance(weight, QuantizedTensor):
+                weight = weight.dequantize()
             for f in s.weight_function:
                 weight = f(weight)
 
@@ -502,7 +504,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):
                 weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                 return weight
             else:
-                return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+                return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
 
         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
             weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
@@ -643,6 +645,24 @@ def forward(self, input, *args, **kwargs):
                     not isinstance(input, QuantizedTensor)):
                 input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
             return self._forward(input, self.weight, self.bias)
+
+        def convert_weight(self, weight, inplace=False, **kwargs):
+            if isinstance(weight, QuantizedTensor):
+                return weight.dequantize()
+            else:
+                return weight
+
+        def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
+            if getattr(self, 'layout_type', None) is not None:
+                weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+            else:
+                weight = weight.to(self.weight.dtype)
+            if return_weight:
+                return weight
+
+            assert inplace_update is False # TODO: eventually remove the inplace_update stuff
+            self.weight = torch.nn.Parameter(weight, requires_grad=False)
+
     return MixedPrecisionOps
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
```
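Taken together, the new `convert_weight` and `set_weight` give `MixedPrecisionOps` the same patch round-trip the scaled-fp8 ops already had: dequantize, patch in a float dtype, requantize. A hypothetical driver (not ComfyUI code; `lora_diff`, `strength`, and `patch_layer` are made up) sketching that flow:

```python
# Hypothetical driver for the round-trip the two new MixedPrecisionOps
# methods enable when patching an fp8-quantized layer with a LoRA diff.
import torch

def patch_layer(layer, lora_diff, strength=1.0, seed=1234):
    # QuantizedTensor -> dequantized plain tensor (pass-through otherwise)
    w = layer.convert_weight(layer.weight)
    w = w + strength * lora_diff.to(device=w.device, dtype=w.dtype)
    # Requantizes via QuantizedTensor.from_float when a layout_type is set;
    # seed > 0 makes the fp8 cast use stochastic rounding.
    layer.set_weight(w, seed=seed)
```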

comfy/quant_ops.py

Lines changed: 14 additions & 7 deletions
```diff
@@ -1,6 +1,7 @@
 import torch
 import logging
 from typing import Tuple, Dict
+import comfy.float
 
 _LAYOUT_REGISTRY = {}
 _GENERIC_UTILS = {}
@@ -393,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
     - orig_dtype: Original dtype before quantization (for casting back)
     """
     @classmethod
-    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
+    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype
 
         if scale is None:
@@ -403,17 +404,23 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)
 
-        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
-        # TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
-        lp_amax = torch.finfo(dtype).max
-        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
+        if inplace_ops:
+            tensor *= (1.0 / scale).to(tensor.dtype)
+        else:
+            tensor = tensor * (1.0 / scale).to(tensor.dtype)
+
+        if stochastic_rounding > 0:
+            tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
+        else:
+            lp_amax = torch.finfo(dtype).max
+            torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
+            tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
 
         layout_params = {
             'scale': scale,
             'orig_dtype': orig_dtype
         }
-        return qdata, layout_params
+        return tensor, layout_params
 
     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
```
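A rough usage sketch of the extended `quantize()` signature, with shape and seed made up. A seed greater than zero routes the cast through `comfy.float.stochastic_rounding` instead of the clamp-and-cast path, and `inplace_ops=True` scales the input tensor's own storage:

```python
# Rough usage sketch; the shape and seed values are illustrative.
import torch
from comfy.quant_ops import TensorCoreFP8Layout

w = torch.randn(1024, 1024, dtype=torch.bfloat16)
# stochastic_rounding > 0 selects the stochastic-rounding cast;
# inplace_ops=True scales w in place, so w must not be needed afterwards.
qdata, params = TensorCoreFP8Layout.quantize(
    w, scale=None, dtype=torch.float8_e4m3fn,
    stochastic_rounding=1234, inplace_ops=True,
)
print(qdata.dtype, params["orig_dtype"])  # torch.float8_e4m3fn torch.bfloat16
```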

comfy/weight_adapter/lora.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -194,6 +194,7 @@ def calculate_weight(
             lora_diff = torch.mm(
                 mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
             ).reshape(weight.shape)
+            del mat1, mat2
             if dora_scale is not None:
                 weight = weight_decompose(
                     dora_scale,
```
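The added `del` drops the last references to the flattened LoRA factors before the DoRA branch allocates its own large intermediates, trimming peak memory. An illustrative sketch of the pattern (sizes invented):

```python
# Illustrative only. Dropping dead references before the next large
# allocation lets the caching allocator reuse their memory.
import torch

mat1 = torch.randn(4096, 64)
mat2 = torch.randn(64, 4096)
lora_diff = torch.mm(mat1, mat2)
del mat1, mat2  # last uses are above; free them before the heavy step
result = lora_diff @ lora_diff  # stand-in for the weight_decompose call
```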
