
Commit 7733d51

try fix flux2 (#9)
1 parent 96c7f18 commit 7733d51

File tree

1 file changed: +8 -42 lines

comfy/model_patcher.py

Lines changed: 8 additions & 42 deletions
@@ -41,6 +41,7 @@
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
+from comfy.quant_ops import QuantizedTensor
 
 def need_mmap() -> bool:
     free_cpu_mem = get_free_memory(torch.device("cpu"))
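Note: only the QuantizedTensor import is new in this hunk. For context, a hedged sketch of the decision need_mmap() appears to make: fall back to memory-mapped weights when free system RAM drops below the configured threshold. Everything past the first line of the body is an assumption about the repo's code, not a quote of it:

import torch
from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb

def need_mmap() -> bool:
    # Assumed condition: use mmap-backed weights when free CPU RAM (bytes)
    # is below the threshold, which is assumed to be configured in GB.
    free_cpu_mem = get_free_memory(torch.device("cpu"))
    return free_cpu_mem < get_mmap_mem_threshold_gb() * (1024 ** 3)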
@@ -54,19 +55,14 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     """
     Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
     """
-    # Move to CPU if needed
-    if t.is_cuda:
-        cpu_tensor = t.cpu()
-    else:
-        cpu_tensor = t
-
     # Create temporary file
     if filename is None:
         temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
     else:
         temp_file = filename
 
     # Save tensor to file
+    cpu_tensor = t.cpu()
     torch.save(cpu_tensor, temp_file)
 
     # If we created a CPU copy from CUDA, delete it to free memory
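Dropping the is_cuda branch is safe because Tensor.cpu() returns the original tensor, with no copy, when it already lives in CPU memory, so the unconditional cpu_tensor = t.cpu() changes nothing in that case. For context, a minimal sketch of the save/load round trip to_mmap builds on (torch.load supports mmap=True since PyTorch 2.1; the path below is illustrative):

import tempfile
import torch

t = torch.randn(4, 4)
path = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
torch.save(t.cpu(), path)                # serialize the tensor to disk once
mmap_t = torch.load(path, mmap=True)     # storage is mapped from the file and paged in lazily
assert torch.equal(t, mmap_t)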
@@ -89,37 +85,7 @@ def _cleanup():
             pass
 
     weakref.finalize(mmap_tensor, _cleanup)
-
-    # # Save original 'to' method
-    # original_to = mmap_tensor.to
-
-    # # Create custom 'to' method that cleans up file when moving to CUDA
-    # def custom_to(*args, **kwargs):
-    #     # Determine target device
-    #     target_device = None
-    #     if len(args) > 0:
-    #         if isinstance(args[0], torch.device):
-    #             target_device = args[0]
-    #         elif isinstance(args[0], str):
-    #             target_device = torch.device(args[0])
-    #     if 'device' in kwargs:
-    #         target_device = kwargs['device']
-    #         if isinstance(target_device, str):
-    #             target_device = torch.device(target_device)
-    #
-    #     # Call original 'to' method first to move data
-    #     result = original_to(*args, **kwargs)
-    #
-    #     # NOTE: Cleanup disabled to avoid blocking model load performance
-    #     # If moved to CUDA, cleanup the mmap file after the move
-    #     if target_device is not None and target_device.type == 'cuda':
-    #         _cleanup()
-    #
-    #     return result
-
-    # # Replace the 'to' method
-    # mmap_tensor.to = custom_to
-
+
     return mmap_tensor
 
 def model_to_mmap(model: torch.nn.Module):
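This hunk deletes the long-dead commented-out custom 'to' wrapper and leaves temp-file lifetime to the weakref.finalize call alone. A minimal sketch of that pattern, assuming the file should disappear once the last reference to the mapped tensor is dropped (here _cleanup takes the path as an argument, whereas the code above closes over temp_file):

import os
import tempfile
import weakref

import torch

def _cleanup(path: str):
    try:
        os.remove(path)
    except OSError:
        pass  # file already gone, or removal raced with interpreter shutdown

path = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
torch.save(torch.zeros(2), path)
mmap_tensor = torch.load(path, mmap=True)
weakref.finalize(mmap_tensor, _cleanup, path)  # runs when mmap_tensor is garbage-collected
del mmap_tensor                                # last reference dropped; the file is removed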
@@ -149,13 +115,13 @@ def convert_fn(t):
         - For Parameters: modify .data and return the Parameter object
         - For buffers (plain Tensors): return new MemoryMappedTensor
         """
-        if isinstance(t, torch.nn.Parameter):
-            # For parameters, modify data in-place and return the parameter
-            if isinstance(t.data, torch.Tensor):
-                t.data = to_mmap(t.data)
+        if isinstance(t, QuantizedTensor):
+            logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
             return t
+        elif isinstance(t, torch.nn.Parameter):
+            new_tensor = to_mmap(t.detach())
+            return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
         elif isinstance(t, torch.Tensor):
-            # For buffers (plain tensors), return the converted tensor
             return to_mmap(t)
         return t
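Two behavioral changes land in this hunk: QuantizedTensor instances (a tensor subclass from comfy.quant_ops, presumably carrying quantization metadata that would not survive a torch.save round trip) are now left untouched, and Parameters are rebuilt around the mmap tensor instead of having .data mutated in place. The docstring context lines still describe the old in-place behavior. A hedged sketch of how such a convert_fn might be walked over a module tree; apply_convert is illustrative, not ComfyUI's actual traversal:

import torch

def apply_convert(module: torch.nn.Module, convert_fn) -> torch.nn.Module:
    # Replace each Parameter object wholesale, mirroring the rebuilt-Parameter
    # approach in the diff rather than assigning to param.data.
    for name, param in list(module.named_parameters(recurse=False)):
        setattr(module, name, convert_fn(param))
    # Buffers are plain tensors; assigning a Tensor to an existing buffer
    # name updates module._buffers[name].
    for name, buf in list(module.named_buffers(recurse=False)):
        setattr(module, name, convert_fn(buf))
    for child in module.children():
        apply_convert(child, convert_fn)
    return module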
