@@ -34,7 +34,7 @@ def infer_module_output_dtypes(
34
34
"""
35
35
outputs = [node for node in module .graph .nodes if node .op == "output" ]
36
36
outputs = outputs [0 ].args
37
- return get_output_dtypes (outputs , truncate_double ) # type: ignore[no-any-return]
37
+ return get_output_dtypes (outputs , truncate_double )
38
38
39
39
40
40
def interpret_module_to_result (
@@ -70,6 +70,29 @@ def interpret_module_to_result(
70
70
)
71
71
72
72
interpreter_result = interpreter .run ()
73
+ # Delete the frozen parameters from the module to release CPU memory
74
+ del interpreter
75
+ for attr in dir (module ):
76
+ if attr .startswith ("_frozen_param" ):
77
+ delattr (module , attr )
78
+ release_memory ()
79
+ logger .debug (
80
+ f"CPU memory usage after clearing frozen parameters and building memory in conversion: { get_cpu_memory_usage ()} MB"
81
+ )
82
+
83
+ serialized_engine = interpreter_result .engine .serialize ()
84
+ with io .BytesIO () as engine_bytes :
85
+ engine_bytes .write (serialized_engine )
86
+ serialized_engine = engine_bytes .getvalue ()
87
+
88
+ interpreter_result = TRTInterpreterResult (
89
+ engine = serialized_engine ,
90
+ input_names = interpreter_result .input_names ,
91
+ output_names = interpreter_result .output_names ,
92
+ weight_name_map = interpreter_result .weight_name_map ,
93
+ requires_output_allocator = interpreter_result .requires_output_allocator ,
94
+ )
95
+
73
96
return interpreter_result
74
97
75
98
@@ -108,22 +131,8 @@ def convert_module(
108
131
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
109
132
)
110
133
111
- # Delete the frozen parameters from the module to release CPU memory
112
- for attr in dir (module ):
113
- if attr .startswith ("_frozen_param" ):
114
- delattr (module , attr )
115
- release_memory ()
116
- logger .debug (
117
- f"CPU memory usage after clearing frozen parameters and building memory in conversion: { get_cpu_memory_usage ()} MB"
118
- )
119
-
120
- serialized_engine = interpreter_result .engine .serialize ()
121
- with io .BytesIO () as engine_bytes :
122
- engine_bytes .write (serialized_engine )
123
- serialized_engine = engine_bytes .getvalue ()
124
- breakpoint ()
125
134
return rt_cls (
126
- serialized_engine = serialized_engine ,
135
+ serialized_engine = interpreter_result . engine ,
127
136
input_binding_names = list (interpreter_result .input_names ),
128
137
output_binding_names = list (interpreter_result .output_names ),
129
138
name = name ,
0 commit comments