Skip to content

Commit d3a8880

Browse files
authored
fix: Fix additional mem copy of the model during re-export (#3302)
1 parent 3e376c4 commit d3a8880

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

py/torch_tensorrt/dynamo/runtime/register_fake_class.py

+19 -2
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ def fake_tensorrt_execute_engine(
2626
modes = ["opt"]
2727

2828
# Get the TRTEngine class and infer output shapes based on input shapes
29-
trt_engine = fake_trt_engine.wrapped_obj.engine
29+
trt_engine = fake_trt_engine.real_obj
3030
outputs_mode_dict = defaultdict(list)
3131
for mode in modes:
3232
input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
@@ -79,7 +79,21 @@ def fake_tensorrt_execute_engine(
7979
@torch._library.register_fake_class("tensorrt::Engine")
8080
class FakeTRTEngine:
8181
def __init__(self, engine_info: List[str]) -> None:
82-
self.engine = torch.classes.tensorrt.Engine(engine_info)
82+
self.version = engine_info[torch.ops.tensorrt.ABI_TARGET_IDX()]
83+
self.name = engine_info[torch.ops.tensorrt.NAME_IDX()]
84+
self.device_info = engine_info[torch.ops.tensorrt.DEVICE_IDX()]
85+
self.serialized_engine = engine_info[torch.ops.tensorrt.ENGINE_IDX()]
86+
self.in_binding_names = engine_info[
87+
torch.ops.tensorrt.INPUT_BINDING_NAMES_IDX()
88+
]
89+
self.out_binding_names = engine_info[
90+
torch.ops.tensorrt.OUTPUT_BINDING_NAMES_IDX()
91+
]
92+
self.hardware_compatible = engine_info[torch.ops.tensorrt.HW_COMPATIBLE_IDX()]
93+
self.serialized_metadata = engine_info[
94+
torch.ops.tensorrt.SERIALIZED_METADATA_IDX()
95+
]
96+
self.target_platform = engine_info[torch.ops.tensorrt.TARGET_PLATFORM_IDX()]
8397

8498
@classmethod
8599
def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
@@ -127,3 +141,6 @@ def infer_outputs(self, input_shapes: List[Any]) -> Any:
127141

128142
def __setstate__(self, serialized_state: List[str]) -> Any:
129143
pass
144+
145+
def __getstate__(self) -> Any:
146+
pass

tests/py/dynamo/models/test_reexport.py

-1
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,6 @@ def forward(self, x):
106106

107107
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
108108
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
109-
torchtrt.save(trt_module, trt_ep_path)
110109

111110
# Reexport
112111
trt_exp_program = torch.export.export(trt_module, (input,), strict=False)

0 commit comments

Comments
 (0)