 from comfy.comfy_types import UnetWrapperFunction
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
+from comfy.quant_ops import QuantizedTensor

 def need_mmap() -> bool:
     free_cpu_mem = get_free_memory(torch.device("cpu"))
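
Only the first line of need_mmap()'s body is visible here. A plausible sketch of the check, assuming get_mmap_mem_threshold_gb() returns a threshold in gigabytes and get_free_memory() returns bytes (both assumptions; the commit does not show the rest of the function, and the imported get_free_disk suggests a disk-space check as well, omitted below):

def need_mmap_sketch() -> bool:
    # Hypothetical reconstruction: fall back to memory-mapped weights when
    # free system RAM drops below the configured threshold (GB -> bytes).
    free_cpu_mem = get_free_memory(torch.device("cpu"))
    return free_cpu_mem < get_mmap_mem_threshold_gb() * (1024 ** 3)
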
@@ -54,19 +55,14 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
5455 """
5556 Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
5657 """
57- # Move to CPU if needed
58- if t .is_cuda :
59- cpu_tensor = t .cpu ()
60- else :
61- cpu_tensor = t
62-
6358 # Create temporary file
6459 if filename is None :
6560 temp_file = tempfile .mktemp (suffix = '.pt' , prefix = 'comfy_mmap_' )
6661 else :
6762 temp_file = filename
6863
6964 # Save tensor to file
65+ cpu_tensor = t .cpu ()
7066 torch .save (cpu_tensor , temp_file )
7167
7268 # If we created a CPU copy from CUDA, delete it to free memory
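
For reference, to_mmap() leans on PyTorch's native mmap support rather than a custom allocator. A minimal standalone sketch of the same round trip (assuming PyTorch >= 2.1, where torch.load() accepts mmap=True):

import tempfile

import torch

t = torch.randn(4, 4)
path = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')

torch.save(t.cpu(), path)                                  # serialize to disk
mmapped = torch.load(path, map_location='cpu', mmap=True)  # file-backed storage

assert torch.equal(t, mmapped)  # same values; pages fault in on demand
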
@@ -89,37 +85,7 @@ def _cleanup():
         pass

     weakref.finalize(mmap_tensor, _cleanup)
-
-    # # Save original 'to' method
-    # original_to = mmap_tensor.to
-
-    # # Create custom 'to' method that cleans up file when moving to CUDA
-    # def custom_to(*args, **kwargs):
-    #     # Determine target device
-    #     target_device = None
-    #     if len(args) > 0:
-    #         if isinstance(args[0], torch.device):
-    #             target_device = args[0]
-    #         elif isinstance(args[0], str):
-    #             target_device = torch.device(args[0])
-    #     if 'device' in kwargs:
-    #         target_device = kwargs['device']
-    #     if isinstance(target_device, str):
-    #         target_device = torch.device(target_device)
-    #
-    #     # Call original 'to' method first to move data
-    #     result = original_to(*args, **kwargs)
-    #
-    #     # NOTE: Cleanup disabled to avoid blocking model load performance
-    #     # If moved to CUDA, cleanup the mmap file after the move
-    #     if target_device is not None and target_device.type == 'cuda':
-    #         _cleanup()
-    #
-    #     return result
-
-    # # Replace the 'to' method
-    # mmap_tensor.to = custom_to
-
+
     return mmap_tensor

 def model_to_mmap(model: torch.nn.Module):
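
The surviving cleanup path ties the temporary file's lifetime to the tensor object itself. A minimal sketch of that weakref.finalize pattern in isolation (the standalone names here are illustrative, not the commit's code):

import os
import tempfile
import weakref

import torch

path = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
torch.save(torch.zeros(2), path)
mmap_tensor = torch.load(path, map_location='cpu', mmap=True)

def _cleanup(p=path):
    try:
        os.remove(p)  # best effort; the file may already be gone
    except OSError:
        pass

# _cleanup runs when mmap_tensor is garbage collected or at interpreter exit,
# so the backing file never outlives the tensor by more than a GC cycle.
weakref.finalize(mmap_tensor, _cleanup)
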
@@ -149,13 +115,13 @@ def convert_fn(t):
         - For Parameters: modify .data and return the Parameter object
         - For buffers (plain Tensors): return new MemoryMappedTensor
         """
-        if isinstance(t, torch.nn.Parameter):
-            # For parameters, modify data in-place and return the parameter
-            if isinstance(t.data, torch.Tensor):
-                t.data = to_mmap(t.data)
+        if isinstance(t, QuantizedTensor):
+            logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
             return t
+        elif isinstance(t, torch.nn.Parameter):
+            new_tensor = to_mmap(t.detach())
+            return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
         elif isinstance(t, torch.Tensor):
-            # For buffers (plain tensors), return the converted tensor
             return to_mmap(t)
         return t

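
The commit only shows convert_fn(); the driver that walks the module tree is outside this hunk. A plausible traversal, assuming model_to_mmap() swaps parameters and buffers submodule by submodule (the setattr/_buffers plumbing below is an assumption, not the commit's code):

def model_to_mmap_sketch(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical driver: route every parameter and buffer through convert_fn.
    for module in model.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            setattr(module, name, convert_fn(param))  # re-registers the Parameter
        for name, buf in list(module.named_buffers(recurse=False)):
            module._buffers[name] = convert_fn(buf)
    return model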