24 changes: 7 additions & 17 deletions comfy/ops.py
@@ -401,15 +401,9 @@ def fp8_linear(self, input):
if dtype not in [torch.float8_e4m3fn]:
return None

tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)

input_shape = input.shape
input_dtype = input.dtype

if len(input.shape) == 3:
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)

scale_weight = self.scale_weight
@@ -422,24 +416,20 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)

# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

uncast_bias_weight(self, w, bias, offload_stream)

if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
return o

return None

@@ -540,12 +530,12 @@ def forward(self, *args, **kwargs):
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
from .quant_ops import QuantizedTensor

QUANT_FORMAT_MIXINS = {
"float8_e4m3fn": {
"dtype": torch.float8_e4m3fn,
"layout_type": TensorCoreFP8Layout,
"layout_type": "TensorCoreFP8Layout",
"parameters": {
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
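Net effect of the ops.py change: `fp8_linear` no longer unsqueezes 2D inputs and reshapes the output afterwards; it quantizes the input in its original shape, wraps both operands in `QuantizedTensor` with the string layout key, and lets `torch.nn.functional.linear` dispatch to the registered handler. A minimal sketch of that path, using stand-in tensors and scales (not ComfyUI state) and assuming an FP8-capable matmul kernel is available:

```python
import torch
from comfy.quant_ops import QuantizedTensor

orig_dtype = torch.bfloat16
x = torch.randn(16, 64, dtype=orig_dtype)                       # 2D input now goes through unchanged
w = (torch.randn(32, 64, dtype=orig_dtype) * 0.1).to(torch.float8_e4m3fn)

scale_input = torch.ones((), dtype=torch.float32)               # stand-in scales
scale_weight = torch.ones((), dtype=torch.float32)

q_in = QuantizedTensor(
    torch.clamp(x, min=-448, max=448).to(torch.float8_e4m3fn).contiguous(),
    "TensorCoreFP8Layout",
    {'scale': scale_input, 'orig_dtype': orig_dtype},
)
q_w = QuantizedTensor(w, "TensorCoreFP8Layout", {'scale': scale_weight, 'orig_dtype': orig_dtype})

# __torch_dispatch__ routes this to the fp8_linear handler registered in quant_ops.py;
# no unsqueeze/reshape bookkeeping is needed around the call anymore.
out = torch.nn.functional.linear(q_in, q_w, None)
```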
27 changes: 22 additions & 5 deletions comfy/quant_ops.py
@@ -123,7 +123,7 @@ def __new__(cls, qdata, layout_type, layout_params):
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)

def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata.contiguous()
@@ -183,11 +183,11 @@ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):

@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)

def dequantize(self) -> torch.Tensor:
return self._layout_type.dequantize(self._qdata, **self._layout_params)
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -379,7 +379,12 @@ def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']


@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}


@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
@@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs):
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output

@@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs):
input_tensor = input_tensor.dequantize()

return torch.nn.functional.linear(input_tensor, weight, bias)


@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)
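The quant_ops.py side of the change references layouts by string key and resolves them through the module-level `LAYOUTS` dict in `from_float`/`dequantize`, and it registers `aten.view`/`aten.t` for the FP8 layout so shape ops keep the tensor quantized (the handler applies the op to the plain FP8 data and re-wraps it with the same layout params). A minimal sketch, assuming a PyTorch build with float8_e4m3fn support; values here are illustrative:

```python
import torch
from comfy.quant_ops import QuantizedTensor

x = torch.randn(8, 16, dtype=torch.float32)
scale = torch.tensor(1.0)

# The string key is looked up in LAYOUTS inside from_float/dequantize.
qt = QuantizedTensor.from_float(x, "TensorCoreFP8Layout", scale=scale, dtype=torch.float8_e4m3fn)

# aten.t is now registered for this layout, so the result stays a QuantizedTensor
# with shape (16, 8) instead of falling back to dequantization.
qt_t = qt.t()
print(type(qt_t).__name__, tuple(qt_t.shape))

# Explicit dequantization still restores the original dtype recorded in layout_params.
print(qt.dequantize().dtype)
```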
8 changes: 4 additions & 4 deletions tests-unit/comfy_quant/test_mixed_precision.py
@@ -14,7 +14,7 @@ def has_gpu():
args.cpu = True

from comfy import ops
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
from comfy.quant_ops import QuantizedTensor


class SimpleModel(torch.nn.Module):
@@ -104,14 +104,14 @@ def test_mixed_precision_load(self):

# Verify weights are wrapped in QuantizedTensor
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")

# Layer 2 should NOT be quantized
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)

# Layer 3 should be quantized
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")

# Verify scales were loaded
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
@@ -155,7 +155,7 @@ def test_state_dict_quantized_preserved(self):
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")

# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
20 changes: 10 additions & 10 deletions tests-unit/comfy_quant/test_quant_registry.py
@@ -25,14 +25,14 @@ def test_creation(self):
scale = torch.tensor(2.0)
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._layout_params['scale'], scale)
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")

def test_dequantize(self):
"""Test explicit dequantization"""
@@ -41,7 +41,7 @@ def test_dequantize(self):
scale = torch.tensor(3.0)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}

qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
dequantized = qt.dequantize()

self.assertEqual(dequantized.dtype, torch.float32)
@@ -54,7 +54,7 @@ def test_from_float(self):

qt = QuantizedTensor.from_float(
float_tensor,
TensorCoreFP8Layout,
"TensorCoreFP8Layout",
scale=scale,
dtype=torch.float8_e4m3fn
)
@@ -77,28 +77,28 @@ def test_detach(self):
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

# Detach should return a new QuantizedTensor
qt_detached = qt.detach()

self.assertIsInstance(qt_detached, QuantizedTensor)
self.assertEqual(qt_detached.shape, qt.shape)
self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")

def test_clone(self):
"""Test clone operation on quantized tensor"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

# Clone should return a new QuantizedTensor
qt_cloned = qt.clone()

self.assertIsInstance(qt_cloned, QuantizedTensor)
self.assertEqual(qt_cloned.shape, qt.shape)
self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")

# Verify it's a deep copy
self.assertIsNot(qt_cloned._qdata, qt._qdata)
@@ -109,7 +109,7 @@ def test_to_device(self):
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

# Moving to same device should work (CPU to CPU)
qt_cpu = qt.to('cpu')
@@ -169,7 +169,7 @@ def test_unsupported_op_dequantizes(self):
scale = torch.tensor(1.0)
a_q = QuantizedTensor.from_float(
a_fp32,
TensorCoreFP8Layout,
"TensorCoreFP8Layout",
scale=scale,
dtype=torch.float8_e4m3fn
)