Skip to content

Commit 7f3e4d4

Browse files
Limit amount of pinned memory on Windows to prevent issues. (#10638)
1 parent a389ee0 commit 7f3e4d4

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

comfy/model_management.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,8 +1082,20 @@ def cast_to_device(tensor, device, dtype, copy=False):
10821082
non_blocking = device_supports_non_blocking(device)
10831083
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
10841084

1085+
1086+
PINNED_MEMORY = {}
1087+
TOTAL_PINNED_MEMORY = 0
1088+
if PerformanceFeature.PinnedMem in args.fast:
1089+
if WINDOWS:
1090+
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows apparently limits pinned memory to about 50% of RAM, so stay safely below that
1091+
else:
1092+
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
1093+
else:
1094+
MAX_PINNED_MEMORY = -1
1095+
10851096
def pin_memory(tensor):
1086-
if PerformanceFeature.PinnedMem not in args.fast:
1097+
global TOTAL_PINNED_MEMORY
1098+
if MAX_PINNED_MEMORY <= 0:
10871099
return False
10881100

10891101
if not is_nvidia():
@@ -1092,13 +1104,21 @@ def pin_memory(tensor):
10921104
if not is_device_cpu(tensor.device):
10931105
return False
10941106

1095-
if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
1107+
size = tensor.numel() * tensor.element_size()
1108+
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
1109+
return False
1110+
1111+
ptr = tensor.data_ptr()
1112+
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
1113+
PINNED_MEMORY[ptr] = size
1114+
TOTAL_PINNED_MEMORY += size
10961115
return True
10971116

10981117
return False
10991118

11001119
def unpin_memory(tensor):
1101-
if PerformanceFeature.PinnedMem not in args.fast:
1120+
global TOTAL_PINNED_MEMORY
1121+
if MAX_PINNED_MEMORY <= 0:
11021122
return False
11031123

11041124
if not is_nvidia():
@@ -1107,7 +1127,11 @@ def unpin_memory(tensor):
11071127
if not is_device_cpu(tensor.device):
11081128
return False
11091129

1110-
if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
1130+
ptr = tensor.data_ptr()
1131+
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
1132+
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
1133+
if len(PINNED_MEMORY) == 0:
1134+
TOTAL_PINNED_MEMORY = 0
11111135
return True
11121136

11131137
return False

0 commit comments

Comments
 (0)