comfy/model_patcher.py (18 changes: 5 additions & 13 deletions)
@@ -57,20 +57,18 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
"""
# Create temporary file
if filename is None:
temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1]
else:
temp_file = filename

# Save tensor to file
cpu_tensor = t.cpu()
torch.save(cpu_tensor, temp_file)

# If we created a CPU copy from CUDA, delete it to free memory
if t.is_cuda:
# If we created a CPU copy from other device, delete it to free memory
if not t.device.type == 'cpu':
del cpu_tensor
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Load with mmap - this doesn't load all data into RAM
mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
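
For reference, the pattern `to_mmap` builds on is the plain save-then-mmap round trip below. This is a standalone sketch, not part of the diff; the temp-file prefix and sizes are arbitrary, and `mmap=True` requires a reasonably recent PyTorch.

```python
import tempfile

import torch

# mkstemp returns (fd, path); like the patched code above, only the path is used here.
_fd, path = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')

t = torch.randn(1024, 1024)
torch.save(t, path)  # serialize to disk in torch's zipfile format

# mmap=True maps the file instead of reading it all into RAM; pages are faulted in on access.
mapped = torch.load(path, map_location='cpu', mmap=True, weights_only=False)
assert torch.equal(mapped, t)
```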
@@ -110,15 +108,9 @@ def model_to_mmap(model: torch.nn.Module):
logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")

def convert_fn(t):
"""Convert function for _apply()

- For Parameters: modify .data and return the Parameter object
- For buffers (plain Tensors): return new MemoryMappedTensor
"""
if isinstance(t, QuantizedTensor):
logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
return t
elif isinstance(t, torch.nn.Parameter):
logging.debug(f"QuantizedTensor detected, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
if isinstance(t, torch.nn.Parameter):
new_tensor = to_mmap(t.detach())
return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
elif isinstance(t, torch.Tensor):
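
For readers unfamiliar with the removed docstring's reference to `_apply()`: `model_to_mmap` presumably drives `convert_fn` through `torch.nn.Module._apply`, the private helper behind `.to()`/`.cuda()` that calls the function on every parameter and buffer. A toy sketch, assuming that wiring; the `Toy` module and the cloning convert function are made up for illustration.

```python
import torch

class Toy(torch.nn.Module):
    """Throwaway module with one parameterized layer and one buffer."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.register_buffer("table", torch.arange(8, dtype=torch.float32))

def convert_fn(t):
    # Stand-in for the real convert_fn: parameters are re-wrapped, buffers returned directly.
    if isinstance(t, torch.nn.Parameter):
        return torch.nn.Parameter(t.detach().clone(), requires_grad=t.requires_grad)
    elif isinstance(t, torch.Tensor):
        return t.clone()
    return t

model = Toy()
model._apply(convert_fn)  # visits linear.weight, linear.bias, and the "table" buffer
```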
comfy/quant_ops.py (45 changes: 44 additions & 1 deletion)
@@ -130,7 +130,19 @@ def __new__(cls, qdata, layout_type, layout_params):
         layout_type: Layout class (subclass of QuantizedLayout)
         layout_params: Dict with layout-specific parameters
         """
-        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
+        # Use as_subclass so the QuantizedTensor instance shares the same
+        # storage and metadata as the underlying qdata tensor. This ensures
+        # torch.save/torch.load and the torch serialization storage scanning
+        # see a valid underlying storage (fixes data_ptr errors).
+        if not isinstance(qdata, torch.Tensor):
+            raise TypeError("qdata must be a torch.Tensor")
+        obj = qdata.as_subclass(cls)
+        # Ensure grad flag is consistent for quantized tensors
+        try:
+            obj.requires_grad_(False)
+        except Exception:
+            pass
+        return obj

     def __init__(self, qdata, layout_type, layout_params):
         self._qdata = qdata
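
The property the new `__new__` relies on is that `Tensor.as_subclass` reuses the source tensor's storage, whereas `_make_wrapper_subclass` creates a wrapper with no storage of its own, which is what broke `torch.save`'s storage scanning. A minimal sketch with a throwaway subclass; `TaggedTensor` exists only for illustration.

```python
import torch

class TaggedTensor(torch.Tensor):
    """Throwaway subclass used only to illustrate as_subclass semantics."""
    pass

base = torch.randn(4, 4)
tagged = base.as_subclass(TaggedTensor)

# Same data pointer: the subclass instance sits on top of the original storage,
# so serialization sees real bytes rather than an empty wrapper.
assert isinstance(tagged, TaggedTensor)
assert tagged.data_ptr() == base.data_ptr()
```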
@@ -570,3 +582,34 @@ def fp8_func(func, args, kwargs):
         ar[0] = plain_input
         return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
     return func(*args, **kwargs)
+
+def _rebuild_quantized_tensor(qdata, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling when qdata is already a tensor."""
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple.
+
+    qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original
+    inner tensor. We call the provided rebuild function with its args to recreate the
+    inner tensor, then wrap it in QuantizedTensor.
+    """
+    rebuild_fn, rebuild_args = qdata_reduce
+    qdata = rebuild_fn(*rebuild_args)
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+# Register custom globals with torch.serialization so torch.load(..., weights_only=True)
+# accepts these during unpickling. Wrapped in try/except for older PyTorch versions.
+try:
+    import torch as _torch_serial
+    if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"):
+        _torch_serial.serialization.add_safe_globals([
+            QuantizedTensor,
+            _rebuild_quantized_tensor,
+            _rebuild_quantized_tensor_from_base,
+        ])
+except Exception:
+    # If add_safe_globals doesn't exist or registration fails, we silently continue.
+    pass
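
The rebuild helpers above only take effect if `QuantizedTensor` pickles through them; the corresponding `__reduce_ex__` is not part of this diff, so the pairing below is a hedged sketch of how it could look, not the actual implementation.

```python
# Hypothetical sketch only: the real __reduce_ex__ on QuantizedTensor is not shown in this diff.
from comfy.quant_ops import _rebuild_quantized_tensor_from_base

def _example_reduce_ex(self, protocol):
    # Reduce the inner tensor first, then hand its (rebuild_fn, rebuild_args) pair plus the
    # layout metadata to the module-level helper, so unpickling recreates the wrapper.
    rebuild_fn, rebuild_args = self._qdata.__reduce_ex__(protocol)[:2]
    return (_rebuild_quantized_tensor_from_base,
            ((rebuild_fn, rebuild_args), self._layout_type, self._layout_params))
```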
tests-unit/comfy_quant/test_quant_registry.py (23 changes: 23 additions & 0 deletions)
@@ -47,6 +47,29 @@ def test_dequantize(self):
         self.assertEqual(dequantized.dtype, torch.float32)
         self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

+    def test_save_load(self):
+        """Test saving and loading a QuantizedTensor with TensorCoreFP8Layout"""
+        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
+        scale = torch.tensor(2.0)
+        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
+
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
+
+        self.assertIsInstance(qt, QuantizedTensor)
+        self.assertEqual(qt.shape, (256, 128))
+        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
+        self.assertEqual(qt._layout_params['scale'], scale)
+        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
+        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
+
+        torch.save(qt, "test.pt")
+        loaded_qt = torch.load("test.pt", weights_only=False)
+        # loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False)
+
+        self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout")
+        self.assertEqual(loaded_qt._layout_params['scale'], scale)
+        self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16)
+
     def test_from_float(self):
         """Test creating QuantizedTensor from float tensor"""
         float_tensor = torch.randn(64, 32, dtype=torch.float32)
Additional changed file (path not shown)
@@ -5,6 +5,11 @@
 import os
 import gc
 import tempfile
+import sys
+
+# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder)
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+
 from comfy.model_patcher import model_to_mmap, to_mmap
