@@ -123,7 +123,6 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
123
123
124
124
def __init__ (
125
125
self ,
126
- cuda_engine : trt .ICudaEngine = None ,
127
126
serialized_engine : Optional [bytes ] = None ,
128
127
input_binding_names : Optional [List [str ]] = None ,
129
128
output_binding_names : Optional [List [str ]] = None ,
@@ -183,19 +182,7 @@ def __init__(
183
182
# Unused currently - to be used by Dynamic Shape support implementation
184
183
self .memory_pool = None
185
184
186
- if cuda_engine :
187
- assert isinstance (
188
- cuda_engine , trt .ICudaEngine
189
- ), "Cuda engine must be a trt.ICudaEngine object"
190
- self .engine = cuda_engine
191
- elif serialized_engine :
192
- assert isinstance (
193
- serialized_engine , bytes
194
- ), "Serialized engine must be a bytes object"
195
- self .engine = serialized_engine
196
- else :
197
- raise ValueError ("Serialized engine or cuda engine must be provided" )
198
-
185
+ self .serialized_engine = serialized_engine
199
186
self .input_names = (
200
187
input_binding_names if input_binding_names is not None else []
201
188
)
@@ -217,6 +204,7 @@ def __init__(
217
204
else False
218
205
)
219
206
self .settings = settings
207
+ self .engine = None
220
208
self .weight_name_map = weight_name_map
221
209
self .target_platform = Platform .current_platform ()
222
210
self .runtime_states = TorchTRTRuntimeStates (
@@ -231,7 +219,7 @@ def __init__(
231
219
self .output_allocator : Optional [DynamicOutputAllocator ] = None
232
220
self .use_output_allocator_outputs = False
233
221
234
- if self .engine and not self .settings .lazy_engine_init :
222
+ if self .serialized_engine is not None and not self .settings .lazy_engine_init :
235
223
self .setup_engine ()
236
224
237
225
def get_streamable_device_memory_budget (self ) -> Any :
@@ -272,22 +260,13 @@ def set_default_device_memory_budget(self) -> int:
272
260
return self ._set_device_memory_budget (budget_bytes )
273
261
274
262
def setup_engine (self ) -> None :
275
-
276
- if isinstance (self .engine , trt .ICudaEngine ):
277
- pass
278
- elif isinstance (self .engine , bytes ):
279
- runtime = trt .Runtime (TRT_LOGGER )
280
- self .engine = runtime .deserialize_cuda_engine (self .engine )
281
- else :
282
- raise ValueError (
283
- "Expected engine as trt.ICudaEngine or serialized engine as bytes"
284
- )
285
-
286
263
assert (
287
264
self .target_platform == Platform .current_platform ()
288
265
), f"TensorRT engine was not built to target current platform (target: { self .target_platform } , current: { Platform .current_platform ()} )"
289
266
290
267
self .initialized = True
268
+ runtime = trt .Runtime (TRT_LOGGER )
269
+ self .engine = runtime .deserialize_cuda_engine (self .serialized_engine )
291
270
if self .settings .enable_weight_streaming :
292
271
self .set_default_device_memory_budget ()
293
272
self .context = self .engine .create_execution_context ()
@@ -323,7 +302,7 @@ def _check_initialized(self) -> None:
323
302
raise RuntimeError ("PythonTorchTensorRTModule is not initialized." )
324
303
325
304
def _on_state_dict (self , state_dict : Dict [str , Any ], prefix : str , _ : Any ) -> None :
326
- state_dict [prefix + "engine" ] = self .engine
305
+ state_dict [prefix + "engine" ] = self .serialized_engine
327
306
state_dict [prefix + "input_names" ] = self .input_names
328
307
state_dict [prefix + "output_names" ] = self .output_names
329
308
state_dict [prefix + "platform" ] = self .target_platform
@@ -338,7 +317,7 @@ def _load_from_state_dict(
338
317
unexpected_keys : Any ,
339
318
error_msgs : Any ,
340
319
) -> None :
341
- self .engine = state_dict [prefix + "engine" ]
320
+ self .serialized_engine = state_dict [prefix + "engine" ]
342
321
self .input_names = state_dict [prefix + "input_names" ]
343
322
self .output_names = state_dict [prefix + "output_names" ]
344
323
self .target_platform = state_dict [prefix + "platform" ]
0 commit comments