Merged
2 changes: 1 addition & 1 deletion comfy/model_patcher.py
@@ -132,7 +132,7 @@ def __init__(self, key, patches, convert_func=None, set_func=None):
     def __call__(self, weight):
         intermediate_dtype = weight.dtype
         if self.convert_func is not None:
-            weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+            weight = self.convert_func(weight, inplace=False)
 
         if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
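Note on the contract change above: the patcher used to force a float32 copy and ask convert_func to work on it in place; it now hands convert_func the stored weight as-is (possibly quantized) and expects a fresh tensor back, so each ops class decides for itself how to upcast or dequantize. A minimal sketch of a convert_func honoring the new inplace=False call (this function is illustrative only; the real implementations are the comfy/ops.py methods further down this diff):

    import torch

    def convert_func(weight, inplace=False):
        # inplace=False: leave the stored weight untouched and return
        # a fresh tensor in a dtype the patch math can use.
        assert not inplace
        return weight.to(dtype=torch.float32)

    stored = torch.randn(4, 4, dtype=torch.float16)
    dense = convert_func(stored, inplace=False)
    assert stored.dtype == torch.float16   # storage unchanged
    assert dense.dtype == torch.float32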
22 changes: 21 additions & 1 deletion comfy/ops.py
@@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if weight_has_function or weight.dtype != dtype:
         with wf_context:
             weight = weight.to(dtype=dtype)
+            if isinstance(weight, QuantizedTensor):
+                weight = weight.dequantize()
             for f in s.weight_function:
                 weight = f(weight)
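The entries in s.weight_function (LoRA and similar patches) expect a dense floating-point tensor, so a QuantizedTensor is expanded before they run. A self-contained sketch of what that dequantize step amounts to for an fp8-with-scale layout (ToyQuantized is a stand-in for comfy.quant_ops.QuantizedTensor, shown only to make the step concrete; requires a PyTorch build with float8 dtypes):

    import torch

    class ToyQuantized:
        # Stand-in: fp8 payload plus a per-tensor scale and the original dtype.
        def __init__(self, qdata, scale, orig_dtype):
            self.qdata, self.scale, self.orig_dtype = qdata, scale, orig_dtype

        def dequantize(self):
            return self.qdata.to(self.orig_dtype) * self.scale

    w = torch.randn(8, 8)
    scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max
    q = ToyQuantized((w / scale).to(torch.float8_e4m3fn), scale, w.dtype)
    dense = q.dequantize()                          # plain float tensor again
    dense = dense + 0.01 * torch.randn_like(dense)  # weight functions can now apply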

@@ -502,7 +504,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):
             weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
             return weight
         else:
-            return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+            return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
 
     def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
         weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
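The non-inplace branch now upcasts explicitly: with model_patcher.py no longer pre-copying to float32, the weight arriving here can still be in fp8, and fp8 tensors generally do not implement ordinary elementwise math, so the scale multiply has to happen in float32. A hedged sketch of that convert path, with scale_weight standing in for the per-tensor scale the op already stores:

    import torch

    scale_weight = torch.tensor(0.05, dtype=torch.float32)
    stored = (torch.randn(16, 16) / scale_weight).to(torch.float8_e4m3fn)

    # Dequantize out-of-place: upcast first, then rescale in float32.
    dense = stored.to(dtype=torch.float32) * scale_weight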
@@ -643,6 +645,24 @@ def forward(self, input, *args, **kwargs):
                     not isinstance(input, QuantizedTensor)):
                 input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
             return self._forward(input, self.weight, self.bias)
+
+        def convert_weight(self, weight, inplace=False, **kwargs):
+            if isinstance(weight, QuantizedTensor):
+                return weight.dequantize()
+            else:
+                return weight
+
+        def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
+            if getattr(self, 'layout_type', None) is not None:
+                weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+            else:
+                weight = weight.to(self.weight.dtype)
+            if return_weight:
+                return weight
+
+            assert inplace_update is False # TODO: eventually remove the inplace_update stuff
+            self.weight = torch.nn.Parameter(weight, requires_grad=False)
+
     return MixedPrecisionOps
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
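The two new methods give MixedPrecisionOps the convert/patch/restore cycle that model_patcher.py drives: convert_weight expands a quantized weight to a dense tensor, patches are applied in full precision, and set_weight re-quantizes the result, threading the seed into stochastic rounding. A toy round-trip under that contract (ToyLayer is hypothetical and only mimics the hooks added above; requires float8 support):

    import torch

    class ToyLayer:
        def __init__(self, weight_fp32):
            self.weight = weight_fp32.to(torch.float8_e4m3fn)   # quantized storage

        def convert_weight(self, weight, inplace=False, **kwargs):
            return weight.to(torch.float32)                     # stands in for dequantize()

        def set_weight(self, weight, seed=None, **kwargs):
            self.weight = weight.to(torch.float8_e4m3fn)        # stands in for re-quantize

    layer = ToyLayer(torch.randn(4, 4))
    dense = layer.convert_weight(layer.weight)       # expand for patching
    dense = dense + 0.01 * torch.randn_like(dense)   # e.g. a LoRA diff
    layer.set_weight(dense, seed=1234)               # store back in low precision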
21 changes: 14 additions & 7 deletions comfy/quant_ops.py
@@ -1,6 +1,7 @@
 import torch
 import logging
 from typing import Tuple, Dict
+import comfy.float
 
 _LAYOUT_REGISTRY = {}
 _GENERIC_UTILS = {}
@@ -393,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
     - orig_dtype: Original dtype before quantization (for casting back)
     """
     @classmethod
-    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
+    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype
 
         if scale is None:
@@ -403,17 +404,23 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)
 
-        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
-        # TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
-        lp_amax = torch.finfo(dtype).max
-        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
+        if inplace_ops:
+            tensor *= (1.0 / scale).to(tensor.dtype)
+        else:
+            tensor = tensor * (1.0 / scale).to(tensor.dtype)
+
+        if stochastic_rounding > 0:
+            tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
+        else:
+            lp_amax = torch.finfo(dtype).max
+            torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
+            tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
 
         layout_params = {
             'scale': scale,
             'orig_dtype': orig_dtype
         }
-        return qdata, layout_params
+        return tensor, layout_params
 
     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
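quantize now has two rounding paths: the default clamp-and-cast (round-to-nearest) and stochastic rounding keyed by a seed, which set_weight uses so repeated re-quantization does not accumulate a deterministic rounding bias; inplace_ops additionally lets the scaling step reuse the input buffer. A standalone sketch of the scale/clamp/cast path (round-to-nearest only; on the stochastic path comfy.float.stochastic_rounding replaces the final cast):

    import torch

    def quantize_fp8(tensor, dtype=torch.float8_e4m3fn):
        # Per-tensor scale mapping the largest magnitude onto the fp8 max.
        scale = tensor.abs().max().to(torch.float32) / torch.finfo(dtype).max
        scaled = tensor * (1.0 / scale).to(tensor.dtype)
        # Clamp into the representable range, then round-to-nearest via the cast.
        lp_amax = torch.finfo(dtype).max
        scaled = scaled.clamp(min=-lp_amax, max=lp_amax)
        return scaled.to(dtype, memory_format=torch.contiguous_format), scale

    qdata, scale = quantize_fp8(torch.randn(32, 32))
    roundtrip = qdata.to(torch.float32) * scale   # dequantize to inspect the error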
1 change: 1 addition & 0 deletions comfy/weight_adapter/lora.py
@@ -194,6 +194,7 @@ def calculate_weight(
             lora_diff = torch.mm(
                 mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
             ).reshape(weight.shape)
+            del mat1, mat2
             if dora_scale is not None:
                 weight = weight_decompose(
                     dora_scale,
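Deleting mat1 and mat2 right after the matmul drops the last references to the flattened LoRA factors before the DoRA branch allocates its own temporaries, lowering peak memory under CUDA's caching allocator. The pattern in isolation (names and shapes here are arbitrary):

    import torch

    up = torch.randn(4096, 64)
    down = torch.randn(64, 4096)
    diff = torch.mm(up, down)   # only this product is needed from here on
    del up, down                # free the factors before the next big allocation
    patched = torch.randn(4096, 4096) + diff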