Skip to content

Commit d99f183

Browse files
committed
Revert to supporting lazy init while reducing memory consumption
1 parent 33ca588 commit d99f183

File tree

4 files changed

+24
-58
lines changed

4 files changed

+24
-58
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -734,11 +734,8 @@ def run(
734734
if interpreter_result is not None: # hit the cache
735735
return interpreter_result # type: ignore[no-any-return]
736736

737-
_LOGGER.debug(
738-
f"CPU memory usage before network construction: {get_cpu_memory_usage()} MB"
739-
)
740737
self._construct_trt_network_def()
741-
_LOGGER.debug(
738+
_LOGGER.info(
742739
f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
743740
)
744741

@@ -758,16 +755,16 @@ def run(
758755
self._create_timing_cache(
759756
builder_config, self.compilation_settings.timing_cache_path
760757
)
761-
_LOGGER.debug(
762-
f"CPU memory usage before engine building: {get_cpu_memory_usage()} MB"
763-
)
758+
764759
cuda_engine = self.builder.build_engine_with_config(
765760
self.ctx.net, builder_config
766761
)
767762
assert cuda_engine
763+
768764
_LOGGER.debug(
769765
f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
770766
)
767+
771768
_LOGGER.info(
772769
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
773770
)

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import io
34
import logging
45
from typing import Any, List, Optional, Sequence
56

@@ -33,7 +34,7 @@ def infer_module_output_dtypes(
3334
"""
3435
outputs = [node for node in module.graph.nodes if node.op == "output"]
3536
outputs = outputs[0].args
36-
return get_output_dtypes(outputs, truncate_double)
37+
return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return]
3738

3839

3940
def interpret_module_to_result(
@@ -113,11 +114,16 @@ def convert_module(
113114
delattr(module, attr)
114115
release_memory()
115116
logger.debug(
116-
f"CPU memory usage after clearing frozen parameters and building memory: {get_cpu_memory_usage()} MB"
117+
f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
117118
)
118119

120+
serialized_engine = interpreter_result.engine.serialize()
121+
with io.BytesIO() as engine_bytes:
122+
engine_bytes.write(serialized_engine)
123+
serialized_engine = engine_bytes.getvalue()
124+
breakpoint()
119125
return rt_cls(
120-
cuda_engine=interpreter_result.engine,
126+
serialized_engine=serialized_engine,
121127
input_binding_names=list(interpreter_result.input_names),
122128
output_binding_names=list(interpreter_result.output_names),
123129
name=name,

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
123123

124124
def __init__(
125125
self,
126-
cuda_engine: trt.ICudaEngine = None,
127126
serialized_engine: Optional[bytes] = None,
128127
input_binding_names: Optional[List[str]] = None,
129128
output_binding_names: Optional[List[str]] = None,
@@ -183,19 +182,7 @@ def __init__(
183182
# Unused currently - to be used by Dynamic Shape support implementation
184183
self.memory_pool = None
185184

186-
if cuda_engine:
187-
assert isinstance(
188-
cuda_engine, trt.ICudaEngine
189-
), "Cuda engine must be a trt.ICudaEngine object"
190-
self.engine = cuda_engine
191-
elif serialized_engine:
192-
assert isinstance(
193-
serialized_engine, bytes
194-
), "Serialized engine must be a bytes object"
195-
self.engine = serialized_engine
196-
else:
197-
raise ValueError("Serialized engine or cuda engine must be provided")
198-
185+
self.serialized_engine = serialized_engine
199186
self.input_names = (
200187
input_binding_names if input_binding_names is not None else []
201188
)
@@ -217,6 +204,7 @@ def __init__(
217204
else False
218205
)
219206
self.settings = settings
207+
self.engine = None
220208
self.weight_name_map = weight_name_map
221209
self.target_platform = Platform.current_platform()
222210
self.runtime_states = TorchTRTRuntimeStates(
@@ -231,7 +219,7 @@ def __init__(
231219
self.output_allocator: Optional[DynamicOutputAllocator] = None
232220
self.use_output_allocator_outputs = False
233221

234-
if self.engine and not self.settings.lazy_engine_init:
222+
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
235223
self.setup_engine()
236224

237225
def get_streamable_device_memory_budget(self) -> Any:
@@ -272,22 +260,13 @@ def set_default_device_memory_budget(self) -> int:
272260
return self._set_device_memory_budget(budget_bytes)
273261

274262
def setup_engine(self) -> None:
275-
276-
if isinstance(self.engine, trt.ICudaEngine):
277-
pass
278-
elif isinstance(self.engine, bytes):
279-
runtime = trt.Runtime(TRT_LOGGER)
280-
self.engine = runtime.deserialize_cuda_engine(self.engine)
281-
else:
282-
raise ValueError(
283-
"Expected engine as trt.ICudaEngine or serialized engine as bytes"
284-
)
285-
286263
assert (
287264
self.target_platform == Platform.current_platform()
288265
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
289266

290267
self.initialized = True
268+
runtime = trt.Runtime(TRT_LOGGER)
269+
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
291270
if self.settings.enable_weight_streaming:
292271
self.set_default_device_memory_budget()
293272
self.context = self.engine.create_execution_context()
@@ -323,7 +302,7 @@ def _check_initialized(self) -> None:
323302
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
324303

325304
def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
326-
state_dict[prefix + "engine"] = self.engine
305+
state_dict[prefix + "engine"] = self.serialized_engine
327306
state_dict[prefix + "input_names"] = self.input_names
328307
state_dict[prefix + "output_names"] = self.output_names
329308
state_dict[prefix + "platform"] = self.target_platform
@@ -338,7 +317,7 @@ def _load_from_state_dict(
338317
unexpected_keys: Any,
339318
error_msgs: Any,
340319
) -> None:
341-
self.engine = state_dict[prefix + "engine"]
320+
self.serialized_engine = state_dict[prefix + "engine"]
342321
self.input_names = state_dict[prefix + "input_names"]
343322
self.output_names = state_dict[prefix + "output_names"]
344323
self.target_platform = state_dict[prefix + "platform"]

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22

33
import base64
44
import copy
5-
import io
65
import logging
76
import pickle
87
from typing import Any, List, Optional, Tuple, Union
98

10-
import tensorrt as trt
119
import torch
1210
from torch_tensorrt._Device import Device
1311
from torch_tensorrt._enums import Platform
@@ -78,7 +76,6 @@ class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
7876

7977
def __init__(
8078
self,
81-
cuda_engine: Optional[trt.ICudaEngine | bytes] = None,
8279
serialized_engine: Optional[bytes] = None,
8380
input_binding_names: Optional[List[str]] = None,
8481
output_binding_names: Optional[List[str]] = None,
@@ -126,22 +123,8 @@ def __init__(
126123
"""
127124
super(TorchTensorRTModule, self).__init__()
128125

129-
if serialized_engine:
130-
assert isinstance(
131-
serialized_engine, bytes
132-
), "Serialized engine must be a bytes object"
133-
self.serialized_engine = serialized_engine
134-
135-
elif cuda_engine:
136-
assert isinstance(
137-
cuda_engine, trt.ICudaEngine
138-
), "Cuda engine must be a trt.ICudaEngine object"
139-
serialized_engine = cuda_engine.serialize()
140-
with io.BytesIO() as engine_bytes:
141-
engine_bytes.write(serialized_engine) # type: ignore
142-
self.serialized_engine = engine_bytes.getvalue()
143-
else:
144-
raise ValueError("Serialized engine or cuda engine must be provided")
126+
if not isinstance(serialized_engine, bytearray):
127+
ValueError("Expected serialized engine as bytearray")
145128

146129
self.input_binding_names = (
147130
input_binding_names if input_binding_names is not None else []
@@ -153,11 +136,12 @@ def __init__(
153136
self.hardware_compatible = settings.hardware_compatible
154137
self.settings = copy.deepcopy(settings)
155138
self.weight_name_map = weight_name_map
139+
self.serialized_engine = serialized_engine
156140
self.engine = None
157141
self.requires_output_allocator = requires_output_allocator
158142

159143
if (
160-
self.serialized_engine
144+
serialized_engine
161145
and not self.settings.lazy_engine_init
162146
and not self.settings.enable_cross_compile_for_windows
163147
):

0 commit comments

Comments
 (0)