40 commits
6e33ee3
debug error
strint Oct 16, 2025
fa19dd4
debug offload
strint Oct 16, 2025
f40e00c
add detail debug
strint Oct 16, 2025
2b22296
add debug log
strint Oct 16, 2025
c1eac55
add debug log
strint Oct 16, 2025
9352987
add log
strint Oct 16, 2025
a207301
rm useless log
strint Oct 16, 2025
71b23d1
rm useless log
strint Oct 16, 2025
e5ff6a1
refine log
strint Oct 16, 2025
5c3c6c0
add debug log of cpu load
strint Oct 17, 2025
6583cc0
debug load mem
strint Oct 17, 2025
49597bf
load remains mmap
strint Oct 17, 2025
21ebcad
debug free mem
strint Oct 20, 2025
4ac827d
unload partial
strint Oct 20, 2025
e9e1d2f
add mmap tensor
strint Oct 20, 2025
4956178
fix log
strint Oct 20, 2025
8aeebbf
fix to
strint Oct 20, 2025
05c2518
refact mmap
strint Oct 20, 2025
2f0d566
refine code
strint Oct 21, 2025
2d010f5
refine code
strint Oct 21, 2025
fff56de
fix format
strint Oct 21, 2025
08e094e
use native mmap
strint Oct 21, 2025
8038393
lazy rm file
strint Oct 21, 2025
98ba311
add env
strint Oct 21, 2025
f3c673d
Merge branch 'master' of https://github.com/siliconflow/ComfyUI into …
strint Oct 22, 2025
aab0e24
fix MMAP_MEM_THRESHOLD_GB default
strint Oct 23, 2025
58d28ed
no limit for offload size
strint Oct 23, 2025
c312733
refine log
strint Oct 23, 2025
dc7c77e
better partial unload
strint Oct 23, 2025
5c5fbdd
debug mmap
strint Nov 17, 2025
d28093f
Merge branch 'master' into refine_offload
doombeaker Nov 26, 2025
96c7f18
Merge branch 'master' into refine_offload
doombeaker Nov 27, 2025
7733d51
try fix flux2 (#9)
strint Dec 4, 2025
211fa31
Merge branch 'master' into refine_offload
doombeaker Dec 8, 2025
1122cd0
allow offload quant (#10)
strint Dec 9, 2025
532eb01
rm comment
strint Dec 9, 2025
2c5b9da
rm debug log
strint Dec 12, 2025
5495b55
rm useless
strint Dec 12, 2025
fa674cc
refine
strint Dec 15, 2025
8b433f2
Merge branch 'master' into offload_to_mmap
strint Dec 18, 2025
1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -143,6 +143,7 @@ def from_string(cls, value: str):
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--offload-reserve-ram-gb", type=float, default=None, help="Set the amount of ram in GB you want to reserve for other use. When the limit is reached, model on vram will be offloaded to mmap to save ram.")

parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
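For reference, a minimal sketch (not part of the diff) of how the new flag surfaces on the parsed args object: argparse converts the dashes in --offload-reserve-ram-gb to underscores, so downstream code reads args.offload_reserve_ram_gb. The parser below is a stand-alone example, not ComfyUI's own.

import argparse

parser = argparse.ArgumentParser()
# Same option shape as the ComfyUI flag above: a float number of GB, default None.
parser.add_argument("--offload-reserve-ram-gb", type=float, default=None)

args = parser.parse_args(["--offload-reserve-ram-gb", "8"])
# argparse stores the value under an underscored attribute name.
assert args.offload_reserve_ram_gb == 8.0
reserve_bytes = int(args.offload_reserve_ram_gb * 1024 ** 3)  # 8 GiB in bytes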
61 changes: 50 additions & 11 deletions comfy/model_management.py
@@ -26,6 +26,28 @@
import platform
import weakref
import gc
import os

from functools import lru_cache

@lru_cache(maxsize=1)
def get_offload_reserve_ram_gb():
    offload_reserve_ram_gb = 0
    try:
        # argparse stores --offload-reserve-ram-gb as args.offload_reserve_ram_gb
        val = getattr(args, 'offload_reserve_ram_gb', None)
    except Exception:
        val = None

    if val is not None:
        try:
            offload_reserve_ram_gb = float(val)
        except Exception:
            logging.warning(f"Invalid args.offload_reserve_ram_gb value: {val}, defaulting to 0")
            offload_reserve_ram_gb = 0
    return offload_reserve_ram_gb

def get_free_disk():
    # Free space on the filesystem that will back the mmap files.
    return psutil.disk_usage("/").free

class VRAMState(Enum):
    DISABLED = 0    #No vram present: no need to move models to vram
@@ -521,16 +543,33 @@ def should_reload_model(self, force_patch_weights=False):
        return False

    def model_unload(self, memory_to_free=None, unpatch_weights=True):
        if memory_to_free is not None:
            if memory_to_free < self.model.loaded_size():
                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
                if freed >= memory_to_free:
                    return False
        self.model.detach(unpatch_weights)
        self.model_finalizer.detach()
        self.model_finalizer = None
        self.real_model = None
        return True
        if memory_to_free is None:
            # free the full model
            memory_to_free = self.model.loaded_size()

        available_memory = get_free_memory(self.model.offload_device)

        # RAM reserved for other system usage
        reserve_ram = get_offload_reserve_ram_gb() * 1024 * 1024 * 1024
        if memory_to_free > available_memory - reserve_ram or memory_to_free < self.model.loaded_size():
            partially_unload = True
        else:
            partially_unload = False

        if partially_unload:
            freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
            if freed < memory_to_free:
                logging.debug(f"Partial unload freed less memory than requested: freed {freed/(1024*1024*1024):.2f} GB, memory_to_free {memory_to_free/(1024*1024*1024):.2f} GB")
        else:
            self.model.detach(unpatch_weights)
            self.model_finalizer.detach()
            self.model_finalizer = None
            self.real_model = None

        if partially_unload:
            return False
        else:
            return True


    def model_use_more_vram(self, extra_memory, force_patch_weights=False):
        return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
@@ -584,7 +623,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
    can_unload = []
    unloaded_models = []

    for i in range(len(current_loaded_models) -1, -1, -1):
    for i in range(len(current_loaded_models) -1, -1):
        shift_model = current_loaded_models[i]
        if shift_model.device == device:
            if shift_model not in keep_loaded and not shift_model.is_dead():
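A standalone sketch (not part of the patch) of the unload policy implemented above: fully detach only when the offload device keeps enough free memory above the configured reserve; otherwise fall back to a partial unload. The helper name decide_unload is illustrative.

def decide_unload(memory_to_free, loaded_size, free_cpu_mem, reserve_ram_gb):
    """Return 'partial' or 'full', mirroring the decision in model_unload above."""
    if memory_to_free is None:
        memory_to_free = loaded_size  # free the full model by default
    reserve_bytes = reserve_ram_gb * 1024 ** 3
    # Partially unload if RAM would dip below the reserve, or if only part is needed.
    if memory_to_free > free_cpu_mem - reserve_bytes or memory_to_free < loaded_size:
        return "partial"
    return "full"

GiB = 1024 ** 3
# e.g. freeing a 12 GiB model with 16 GiB of free RAM and an 8 GiB reserve -> "partial"
print(decide_unload(12 * GiB, 12 * GiB, 16 * GiB, 8))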
96 changes: 94 additions & 2 deletions comfy/model_patcher.py
@@ -27,6 +27,10 @@
from typing import Callable, Optional

import torch
import os
import tempfile
import weakref
import gc

import comfy.float
import comfy.hooks
@@ -37,6 +41,80 @@
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
from comfy.model_management import get_free_memory, get_offload_reserve_ram_gb, get_free_disk

def enable_offload_to_mmap() -> bool:
    if comfy.utils.DISABLE_MMAP:
        return False

    free_cpu_mem = get_free_memory(torch.device("cpu"))
    offload_reserve_ram_gb = get_offload_reserve_ram_gb()
    if free_cpu_mem <= offload_reserve_ram_gb * 1024 * 1024 * 1024:
        logging.debug(f"Enabling offload to mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024):.2f} GB <= reserved {offload_reserve_ram_gb} GB")
        return True

    return False

def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
    """
    Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
    """
    # Create a temporary file (close the descriptor so only the path is kept)
    if filename is None:
        fd, temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')
        os.close(fd)
    else:
        temp_file = filename

    # Save tensor to file
    cpu_tensor = t.cpu()
    torch.save(cpu_tensor, temp_file)

    # If we created a CPU copy from another device, delete it to free memory
    if t.device.type != 'cpu':
        del cpu_tensor
        gc.collect()

    # Load with mmap - this doesn't load all data into RAM
    mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)

    # Register cleanup callback - will be called when the tensor is garbage collected
    def _cleanup():
        try:
            if os.path.exists(temp_file):
                os.remove(temp_file)
                logging.debug(f"Cleaned up mmap file: {temp_file}")
        except Exception:
            pass

    weakref.finalize(mmap_tensor, _cleanup)

    return mmap_tensor

def model_to_mmap(model: torch.nn.Module):
    """Convert all parameters and buffers to memory-mapped tensors

    Args:
        model: PyTorch module to convert

    Returns:
        The same model with all tensors converted to memory-mapped format
    """
    free_cpu_mem = get_free_memory(torch.device("cpu"))
    logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024):.2f} GB")

    def convert_fn(t):
        if isinstance(t, torch.nn.Parameter):
            new_tensor = to_mmap(t.detach())
            return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
        elif isinstance(t, torch.Tensor):
            return to_mmap(t)
        return t

    new_model = model._apply(convert_fn)
    free_cpu_mem = get_free_memory(torch.device("cpu"))
    logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024):.2f} GB")
    return new_model


def string_to_seed(data):
@@ -506,6 +584,7 @@ def get_model_object(self, name: str) -> torch.nn.Module:
        return comfy.utils.get_attr(self.model, name)

    def model_patches_to(self, device):
        # TODO(sf): to mmap
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
@@ -853,9 +932,15 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
        self.model.current_weight_patches_uuid = None
        self.backup.clear()

        if device_to is not None:
            self.model.to(device_to)
            if enable_offload_to_mmap():
                # offload to mmap
                model_to_mmap(self.model)
            else:
                self.model.to(device_to)
            self.model.device = device_to

        self.model.model_loaded_weight_memory = 0
        self.model.model_offload_buffer_memory = 0

@@ -914,7 +999,14 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
if enable_offload_to_mmap():
if get_free_disk() < module_mem:
logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB")
break
# offload to mmap
model_to_mmap(m)
else:
m.to(device_to)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
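For reference, a usage sketch (outside this diff) of the save-then-mmap round-trip that to_mmap builds on, using only stock PyTorch (torch.load's mmap=True argument, available since PyTorch 2.1); the file path here is illustrative.

import os
import tempfile
import torch

# Save a CPU tensor, then map it back from disk instead of reading it into RAM.
fd, path = tempfile.mkstemp(suffix=".pt", prefix="mmap_demo_")
os.close(fd)

t = torch.randn(1024, 1024)
torch.save(t, path)

# mmap=True keeps the storage backed by the file; pages are faulted in on access.
mapped = torch.load(path, map_location="cpu", mmap=True, weights_only=False)
print(torch.equal(t, mapped))  # True

del mapped  # drop the mapping before removing the backing file
os.remove(path)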
45 changes: 44 additions & 1 deletion comfy/quant_ops.py
@@ -130,7 +130,19 @@ def __new__(cls, qdata, layout_type, layout_params):
            layout_type: Layout class (subclass of QuantizedLayout)
            layout_params: Dict with layout-specific parameters
        """
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
        # Use as_subclass so the QuantizedTensor instance shares the same
        # storage and metadata as the underlying qdata tensor. This ensures
        # torch.save/torch.load and the torch serialization storage scanning
        # see a valid underlying storage (fixes data_ptr errors).
        if not isinstance(qdata, torch.Tensor):
            raise TypeError("qdata must be a torch.Tensor")
        obj = qdata.as_subclass(cls)
        # Ensure the grad flag is consistent for quantized tensors
        try:
            obj.requires_grad_(False)
        except Exception:
            pass
        return obj

    def __init__(self, qdata, layout_type, layout_params):
        self._qdata = qdata
@@ -578,3 +590,34 @@ def fp8_func(func, args, kwargs):
        ar[0] = plain_input
        return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
    return func(*args, **kwargs)

def _rebuild_quantized_tensor(qdata, layout_type, layout_params):
    """Rebuild QuantizedTensor during unpickling when qdata is already a tensor."""
    return QuantizedTensor(qdata, layout_type, layout_params)


def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params):
    """Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple.

    qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original
    inner tensor. We call the provided rebuild function with its args to recreate the
    inner tensor, then wrap it in QuantizedTensor.
    """
    rebuild_fn, rebuild_args = qdata_reduce
    qdata = rebuild_fn(*rebuild_args)
    return QuantizedTensor(qdata, layout_type, layout_params)


# Register custom globals with torch.serialization so torch.load(..., weights_only=True)
# accepts these during unpickling. Wrapped in try/except for older PyTorch versions.
try:
    import torch as _torch_serial
    if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"):
        _torch_serial.serialization.add_safe_globals([
            QuantizedTensor,
            _rebuild_quantized_tensor,
            _rebuild_quantized_tensor_from_base,
        ])
except Exception:
    # If add_safe_globals doesn't exist or registration fails, we silently continue.
    pass
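A minimal sketch (not part of the patch) of why the as_subclass route serializes cleanly: a subclass produced by Tensor.as_subclass shares the base tensor's storage, so torch.save/torch.load can handle it directly, whereas _make_wrapper_subclass builds a wrapper with no storage of its own. The TaggedTensor class below is illustrative only.

import os
import torch

class TaggedTensor(torch.Tensor):
    """Toy tensor subclass carrying an extra attribute, built via as_subclass."""
    @staticmethod
    def wrap(data, tag):
        obj = data.as_subclass(TaggedTensor)  # shares data's storage and metadata
        obj.tag = tag
        return obj

t = TaggedTensor.wrap(torch.ones(4, dtype=torch.float32), tag="demo")
torch.save(t, "tagged.pt")
loaded = torch.load("tagged.pt", weights_only=False)
print(type(loaded).__name__, torch.equal(loaded, torch.ones(4)))  # TaggedTensor True
os.remove("tagged.pt")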
23 changes: 23 additions & 0 deletions tests-unit/comfy_quant/test_quant_registry.py
@@ -47,6 +47,29 @@ def test_dequantize(self):
        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_save_load(self):
        """Test saving and loading a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")

        torch.save(qt, "test.pt")
        loaded_qt = torch.load("test.pt", weights_only=False)
        # loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False)

        self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout")
        self.assertEqual(loaded_qt._layout_params['scale'], scale)
        self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16)
    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)