@@ -26,7 +26,7 @@ def fake_tensorrt_execute_engine(
26
26
modes = ["opt" ]
27
27
28
28
# Get the TRTEngine class and infer output shapes based on input shapes
29
- trt_engine = fake_trt_engine .wrapped_obj . engine
29
+ trt_engine = fake_trt_engine .real_obj
30
30
outputs_mode_dict = defaultdict (list )
31
31
for mode in modes :
32
32
input_shapes = [unwrap_tensor_shape (input , mode = mode ) for input in inputs ]
@@ -79,7 +79,21 @@ def fake_tensorrt_execute_engine(
79
79
@torch ._library .register_fake_class ("tensorrt::Engine" )
80
80
class FakeTRTEngine :
81
81
def __init__ (self , engine_info : List [str ]) -> None :
82
- self .engine = torch .classes .tensorrt .Engine (engine_info )
82
+ self .version = engine_info [torch .ops .tensorrt .ABI_TARGET_IDX ()]
83
+ self .name = engine_info [torch .ops .tensorrt .NAME_IDX ()]
84
+ self .device_info = engine_info [torch .ops .tensorrt .DEVICE_IDX ()]
85
+ self .serialized_engine = engine_info [torch .ops .tensorrt .ENGINE_IDX ()]
86
+ self .in_binding_names = engine_info [
87
+ torch .ops .tensorrt .INPUT_BINDING_NAMES_IDX ()
88
+ ]
89
+ self .out_binding_names = engine_info [
90
+ torch .ops .tensorrt .OUTPUT_BINDING_NAMES_IDX ()
91
+ ]
92
+ self .hardware_compatible = engine_info [torch .ops .tensorrt .HW_COMPATIBLE_IDX ()]
93
+ self .serialized_metadata = engine_info [
94
+ torch .ops .tensorrt .SERIALIZED_METADATA_IDX ()
95
+ ]
96
+ self .target_platform = engine_info [torch .ops .tensorrt .TARGET_PLATFORM_IDX ()]
83
97
84
98
@classmethod
85
99
def __obj_unflatten__ (cls , flattened_tq : Any ) -> Any :
@@ -127,3 +141,6 @@ def infer_outputs(self, input_shapes: List[Any]) -> Any:
127
141
128
142
def __setstate__ (self , serialized_state : List [str ]) -> Any :
129
143
pass
144
+
145
+ def __getstate__ (self ) -> Any :
146
+ pass
0 commit comments