❓ [Question] Internal Error - given invalid tensor name #1844

Closed
@DanielLevi6

Description

❓ Question

I want to convert a Torch model (from Python) to a runtime model (in C++) using the torch.fx capabilities. That would let me accelerate a model that isn't fully supported by TensorRT.
I understand that this flow is experimental, so I followed the examples given in this repository.

By using this example:
https://github.com/pytorch/TensorRT/blob/main/examples/fx/fx2trt_example_next.py

I got some internal errors while running this part of the code (and also while running inference after that, but the error messages are identical, so I guess it's related):
trt_mod = TRTModule(
    name="my_module",
    serialized_engine=engine_str,
    input_binding_names=r.input_names,
    output_binding_names=r.output_names,
    target_device=Device(f"cuda:{torch.cuda.current_device()}"),
)
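
For context, engine_str and r in this snippet come from the interpreter step earlier in that example; roughly like this (paraphrased from the example, so the exact arguments may differ):

from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter

# `traced_mod` is the acc-traced FX submodule and `inputs` its example
# input tensors, as set up earlier in the example (names paraphrased here).
interp = TRTInterpreter(
    traced_mod,
    InputTensorSpec.from_tensors(inputs),
    explicit_batch_dimension=True,
)
r = interp.run()                    # result carries engine, input_names, output_names
engine_str = r.engine.serialize()   # the serialized engine passed to TRTModule above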

The error messages are:
ERROR: [Torch-TensorRT] - 3: [engine.cpp::getProfileObliviousBindingIndex::1386] Error Code 3: Internal Error (getTensorShape given invalid tensor name: input_0)
ERROR: [Torch-TensorRT] - 3: [engine.cpp::getProfileObliviousBindingIndex::1386] Error Code 3: Internal Error (getTensorDataType given invalid tensor name: input_0)
ERROR: [Torch-TensorRT] - 3: [engine.cpp::getProfileObliviousBindingIndex::1386] Error Code 3: Internal Error (getTensorShape given invalid tensor name: output_0)
ERROR: [Torch-TensorRT] - 3: [engine.cpp::getProfileObliviousBindingIndex::1386] Error Code 3: Internal Error (getTensorDataType given invalid tensor name: output_0)
What can cause these errors?
I tried to find another way to define the model inputs and outputs (which might affect the input and output names, as the error messages hint), but I don't see one in the examples.
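
To check whether the names are really the problem, here is a minimal inspection sketch (assuming the tensorrt Python package from TensorRT >= 8.5, which is what the getTensorShape/getTensorDataType messages suggest). It lists the I/O tensor names the deserialized engine actually has, so they can be compared against input_binding_names/output_binding_names:

import tensorrt as trt

# Deserialize the same engine bytes that were handed to TRTModule.
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(engine_str)

# TensorRT >= 8.5 API: enumerate I/O tensors by name.
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    print(name, engine.get_tensor_mode(name))  # trt.TensorIOMode.INPUT / OUTPUT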

What you have already tried

I have already tried the example I linked above, as well as another flow I found on the PyTorch forum:
https://discuss.pytorch.org/t/using-torchtrt-fx-backend-on-c/170639/6

The code for this flow is:

import torch
import torch_tensorrt

model_fx = model_fx.cuda()
inputs_fx = [i.cuda() for i in inputs_fx]
trt_fx_module_f16 = torch_tensorrt.compile(
    model_fx,
    ir="fx",
    inputs=inputs_fx,
    enabled_precisions={torch.float16},
    use_experimental_fx_rt=True,
    explicit_batch_dimension=True,
)
torch.save(trt_fx_module_f16, "trt.pt")
reload_trt_mod = torch.load("trt.pt")
scripted_fx_module = torch.jit.trace(trt_fx_module_f16, example_inputs=inputs_fx)
scripted_fx_module.save("/tmp/scripted_fx_module.ts")
scripted_fx_module = torch.jit.load("/tmp/scripted_fx_module.ts")  # this can also be loaded in C++

The error is the same; it appears while running torch_tensorrt.compile with the use_experimental_fx_rt=True flag, as in the snippet above.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0): 1.13.1
  • CPU Architecture: x86-64
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): -
  • Are you using local sources or building from archives: I used the pre-built version of Torch-TensorRT 1.3.0 release
  • Python version: 3.8.10
  • CUDA version: 11.8
  • GPU models and configuration: NVIDIA T1000
  • Any other relevant information: -
