@@ -1082,8 +1082,20 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)

+
+PINNED_MEMORY = {}
+TOTAL_PINNED_MEMORY = 0
+if PerformanceFeature.PinnedMem in args.fast:
+    if WINDOWS:
+        MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45  # Windows limit is apparently 50%
+    else:
+        MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
+else:
+    MAX_PINNED_MEMORY = -1
+
 def pin_memory(tensor):
-    if PerformanceFeature.PinnedMem not in args.fast:
+    global TOTAL_PINNED_MEMORY
+    if MAX_PINNED_MEMORY <= 0:
         return False

     if not is_nvidia():
@@ -1092,13 +1104,21 @@ def pin_memory(tensor):
     if not is_device_cpu(tensor.device):
         return False

-    if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
+    size = tensor.numel() * tensor.element_size()
+    if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
+        return False
+
+    ptr = tensor.data_ptr()
+    if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
+        PINNED_MEMORY[ptr] = size
+        TOTAL_PINNED_MEMORY += size
         return True

     return False

 def unpin_memory(tensor):
-    if PerformanceFeature.PinnedMem not in args.fast:
+    global TOTAL_PINNED_MEMORY
+    if MAX_PINNED_MEMORY <= 0:
         return False

     if not is_nvidia():
@@ -1107,7 +1127,11 @@ def unpin_memory(tensor):
     if not is_device_cpu(tensor.device):
         return False

-    if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
+    ptr = tensor.data_ptr()
+    if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
+        TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
+        if len(PINNED_MEMORY) == 0:
+            TOTAL_PINNED_MEMORY = 0
         return True

     return False
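For context, the hunks above cap how much host RAM the PinnedMem fast path may page-lock: pin_memory() now registers a CPU tensor's buffer with cudaHostRegister only while the running byte total stays under MAX_PINNED_MEMORY, and unpin_memory() returns the bytes to the budget when the buffer is unregistered. The sketch below isolates just that bookkeeping pattern so it runs without a GPU; the PinnedBudget class and its _register/_unregister hooks are hypothetical stand-ins for the torch.cuda.cudart().cudaHostRegister / cudaHostUnregister calls the commit actually makes, and the real code keeps the state in module-level globals rather than a class.

# Minimal sketch of the budget bookkeeping added above (names are illustrative,
# not from the commit): a dict of data_ptr -> size plus a running total that is
# checked against a byte budget before pinning.

class PinnedBudget:
    def __init__(self, max_bytes):
        self.max_bytes = max_bytes      # mirrors MAX_PINNED_MEMORY (<= 0 disables pinning)
        self.pinned = {}                # mirrors PINNED_MEMORY: data_ptr -> size in bytes
        self.total = 0                  # mirrors TOTAL_PINNED_MEMORY

    def _register(self, ptr, size):     # stand-in for cudaHostRegister(ptr, size, 1) == 0
        return True

    def _unregister(self, ptr):         # stand-in for cudaHostUnregister(ptr) == 0
        return True

    def pin(self, ptr, size):
        if self.max_bytes <= 0:
            return False
        if self.total + size > self.max_bytes:
            return False                # over budget: caller falls back to pageable memory
        if self._register(ptr, size):
            self.pinned[ptr] = size
            self.total += size
            return True
        return False

    def unpin(self, ptr):
        if self.max_bytes <= 0 or ptr not in self.pinned:
            return False
        if self._unregister(ptr):
            self.total -= self.pinned.pop(ptr)
            if not self.pinned:         # same guard as the diff: reset the total when empty
                self.total = 0
            return True
        return False


budget = PinnedBudget(max_bytes=8 * 1024**3)          # e.g. an 8 GiB budget
assert budget.pin(ptr=0x1000, size=1024**3)           # 1 GiB fits
assert not budget.pin(ptr=0x2000, size=16 * 1024**3)  # too large, stays pageable
assert budget.unpin(ptr=0x1000)

Two design points worth noting, as far as the diff shows: the budget is 95% of system RAM except on Windows, where it is held to 45% because of the roughly 50% pinned-memory limit the inline comment mentions, and TOTAL_PINNED_MEMORY is reset to zero whenever the tracking dict empties, presumably to correct any accounting drift that might accumulate.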