comfy/model_patcher.py: 8 additions & 42 deletions (50 changed lines)

@@ -41,6 +41,7 @@
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
+from comfy.quant_ops import QuantizedTensor

 def need_mmap() -> bool:
     free_cpu_mem = get_free_memory(torch.device("cpu"))
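The rest of need_mmap() is collapsed in this view. A hypothetical sketch of the gating it likely performs, based only on the names imported above (need_mmap_sketch and the GiB conversion are illustrative, not the PR's code; the real check may also consult get_free_disk):

import torch
from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb

def need_mmap_sketch() -> bool:
    # Assumption: get_free_memory returns bytes and the threshold is in GB.
    free_gb = get_free_memory(torch.device("cpu")) / (1024 ** 3)
    return free_gb < get_mmap_mem_threshold_gb()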
@@ -54,19 +55,14 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     """
     Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
     """
-    # Move to CPU if needed
-    if t.is_cuda:
-        cpu_tensor = t.cpu()
-    else:
-        cpu_tensor = t

     # Create temporary file
     if filename is None:
         temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
     else:
         temp_file = filename

-    # Save tensor to file
+    cpu_tensor = t.cpu()
     torch.save(cpu_tensor, temp_file)

     # If we created a CPU copy from CUDA, delete it to free memory
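The save path above is half of the round trip; the load side (collapsed in this view) presumably reopens the file with torch.load's mmap flag. A minimal standalone demo of that PyTorch-native mechanism (PyTorch >= 2.1; variable names are mine, not ComfyUI's):

import tempfile
import torch

t = torch.randn(4, 4)
with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
    path = f.name
torch.save(t, path)

# mmap=True maps the file instead of reading it eagerly; pages are
# faulted in from disk on first access, keeping resident RAM low.
mmap_t = torch.load(path, mmap=True, weights_only=True)
assert torch.equal(t, mmap_t)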
@@ -89,37 +85,7 @@ def _cleanup():
             pass

     weakref.finalize(mmap_tensor, _cleanup)

-    # # Save original 'to' method
-    # original_to = mmap_tensor.to
-
-    # # Create custom 'to' method that cleans up file when moving to CUDA
-    # def custom_to(*args, **kwargs):
-    #     # Determine target device
-    #     target_device = None
-    #     if len(args) > 0:
-    #         if isinstance(args[0], torch.device):
-    #             target_device = args[0]
-    #         elif isinstance(args[0], str):
-    #             target_device = torch.device(args[0])
-    #     if 'device' in kwargs:
-    #         target_device = kwargs['device']
-    #     if isinstance(target_device, str):
-    #         target_device = torch.device(target_device)
-    #
-    #     # Call original 'to' method first to move data
-    #     result = original_to(*args, **kwargs)
-    #
-    #     # NOTE: Cleanup disabled to avoid blocking model load performance
-    #     # If moved to CUDA, cleanup the mmap file after the move
-    #     if target_device is not None and target_device.type == 'cuda':
-    #         _cleanup()
-    #
-    #     return result
-
-    # # Replace the 'to' method
-    # mmap_tensor.to = custom_to
-

     return mmap_tensor
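The surviving weakref.finalize call ties the temp file's lifetime to the tensor object, replacing the deleted custom_to hook. A standalone sketch of that pattern (Owner and the demo file are illustrative):

import os
import tempfile
import weakref

with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
    path = f.name

class Owner:
    pass

obj = Owner()
# Like _cleanup above: delete the backing file once the owner is collected.
weakref.finalize(obj, os.remove, path)

del obj                       # finalizer fires here (CPython refcounting)
print(os.path.exists(path))   # False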
@@ -149,13 +115,13 @@ def convert_fn(t):
         - For Parameters: modify .data and return the Parameter object
         - For buffers (plain Tensors): return new MemoryMappedTensor
         """
-        if isinstance(t, torch.nn.Parameter):
-            # For parameters, modify data in-place and return the parameter
-            if isinstance(t.data, torch.Tensor):
-                t.data = to_mmap(t.data)
+        if isinstance(t, QuantizedTensor):
+            logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
             return t
+        elif isinstance(t, torch.nn.Parameter):
+            new_tensor = to_mmap(t.detach())
+            return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
         elif isinstance(t, torch.Tensor):
             # For buffers (plain tensors), return the converted tensor
             return to_mmap(t)
         return t
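The driver loop in model_to_mmap is collapsed in this view. One plausible way a convert_fn with this contract (QuantizedTensor passed through, Parameter in -> new Parameter out, buffer in -> Tensor out) gets applied over a model; apply_convert is a hypothetical name, not the PR's code:

import torch

def apply_convert(model: torch.nn.Module, convert_fn):
    # Assumption: walk every submodule and swap each parameter/buffer
    # for convert_fn's result, without recursing twice over shared ones.
    for module in model.modules():
        for name, p in list(module.named_parameters(recurse=False)):
            setattr(module, name, convert_fn(p))          # Parameter -> Parameter
        for name, b in list(module.named_buffers(recurse=False)):
            module.register_buffer(name, convert_fn(b))   # Tensor -> Tensor
    return model

# e.g. apply_convert(model, convert_fn) after deciding need_mmap() is True.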
