Skip to content

Commit 66b40bd

Browse files
committed
Added a potential solution for Windows
1 parent d99f183 commit 66b40bd

File tree

2 files changed

+55
-16
lines changed

2 files changed

+55
-16
lines changed

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def infer_module_output_dtypes(
3434
"""
3535
outputs = [node for node in module.graph.nodes if node.op == "output"]
3636
outputs = outputs[0].args
37-
return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return]
37+
return get_output_dtypes(outputs, truncate_double)
3838

3939

4040
def interpret_module_to_result(
@@ -70,6 +70,29 @@ def interpret_module_to_result(
7070
)
7171

7272
interpreter_result = interpreter.run()
73+
# Delete the frozen parameters from the module to release CPU memory
74+
del interpreter
75+
for attr in dir(module):
76+
if attr.startswith("_frozen_param"):
77+
delattr(module, attr)
78+
release_memory()
79+
logger.debug(
80+
f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
81+
)
82+
83+
serialized_engine = interpreter_result.engine.serialize()
84+
with io.BytesIO() as engine_bytes:
85+
engine_bytes.write(serialized_engine)
86+
serialized_engine = engine_bytes.getvalue()
87+
88+
interpreter_result = TRTInterpreterResult(
89+
engine=serialized_engine,
90+
input_names=interpreter_result.input_names,
91+
output_names=interpreter_result.output_names,
92+
weight_name_map=interpreter_result.weight_name_map,
93+
requires_output_allocator=interpreter_result.requires_output_allocator,
94+
)
95+
7396
return interpreter_result
7497

7598

@@ -108,22 +131,8 @@ def convert_module(
108131
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
109132
)
110133

111-
# Delete the frozen parameters from the module to release CPU memory
112-
for attr in dir(module):
113-
if attr.startswith("_frozen_param"):
114-
delattr(module, attr)
115-
release_memory()
116-
logger.debug(
117-
f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
118-
)
119-
120-
serialized_engine = interpreter_result.engine.serialize()
121-
with io.BytesIO() as engine_bytes:
122-
engine_bytes.write(serialized_engine)
123-
serialized_engine = engine_bytes.getvalue()
124-
breakpoint()
125134
return rt_cls(
126-
serialized_engine=serialized_engine,
135+
serialized_engine=interpreter_result.engine,
127136
input_binding_names=list(interpreter_result.input_names),
128137
output_binding_names=list(interpreter_result.output_names),
129138
name=name,

py/torch_tensorrt/dynamo/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,7 @@ def get_cpu_memory_usage() -> Any:
868868

869869

870870
def release_memory() -> None:
871+
gc.collect()
871872
if torch.cuda.is_available():
872873
torch.cuda.synchronize()
873874
torch.cuda.empty_cache()
@@ -881,3 +882,32 @@ def release_memory() -> None:
881882
logger.warning("Failed to release CPU memory.")
882883
except Exception:
883884
logger.warning("Failed to release CPU memory.")
885+
886+
elif platform.system() == "Windows":
887+
from ctypes import wintypes
888+
889+
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
890+
psapi = ctypes.WinDLL("psapi", use_last_error=True)
891+
892+
GetCurrentProcess = kernel32.GetCurrentProcess
893+
GetCurrentProcess.restype = wintypes.HANDLE
894+
hproc = GetCurrentProcess()
895+
896+
HeapSetInformation = kernel32.HeapSetInformation
897+
HeapSetInformation.argtypes = [
898+
wintypes.HANDLE,
899+
ctypes.c_int,
900+
ctypes.c_void_p,
901+
ctypes.c_size_t,
902+
]
903+
HeapSetInformation.restype = wintypes.BOOL
904+
GetProcessHeap = kernel32.GetProcessHeap
905+
GetProcessHeap.restype = wintypes.HANDLE
906+
ok = False
907+
try:
908+
HeapOptimizeResources = 3
909+
hheap = GetProcessHeap()
910+
if HeapSetInformation(hheap, HeapOptimizeResources, None, 0):
911+
ok = True
912+
except Exception:
913+
logger.warning("Failed to release CPU memory.")

0 commit comments

Comments (0)