🐛 [Bug] Unable to compile the model using torch tensorrt #1565

Open
@IamExperimenting

Bug Description

Hi team, I have built an object detection model using the torchvision Faster R-CNN model. I need to deploy this model on NVIDIA Triton Inference Server, so I'm trying to compile it with torch_tensorrt, but the compilation fails.

@narendasan @gs-olive

To Reproduce

Steps to reproduce the behavior:

import torch, tensorrt, torch_tensorrt, torchvision
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn().eval()

trt_module = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 720, 1280))],  # input shape
    enabled_precisions={torch.half},  # run with FP16
)
# Save the TensorRT-embedded TorchScript module
torch.jit.save(trt_module, "trt_torchscript_module.ts")

Expected behavior

The PyTorch model should compile successfully with the torch_tensorrt library.

Environment

OS: Ubuntu 20.04
Python: 3.10.8

Build information about Torch-TensorRT can be found by turning on debug messages

TensorRT version: 8.5.2.2

  • Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
  • PyTorch Version (e.g. 1.0): 1.13.1
  • CPU Architecture:
  • OS (e.g., Linux): Linux - Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): conda
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.8
  • CUDA version: 11.6
  • GPU models and configuration: no
  • Any other relevant information:
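The version fields above can also be gathered programmatically, which avoids transcription slips when filing a report. The following is a minimal sketch using only the standard library; the helper names `collect_versions` and `format_report` are hypothetical, and the distribution names in the commented usage line (`torch`, `torch-tensorrt`, `torchvision`, `tensorrt`) are assumptions about how the packages are registered in the environment, which may differ for conda installs.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable


def collect_versions(packages: Iterable[str]) -> Dict[str, str]:
    """Map each distribution name to its installed version, or 'not installed'."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not registered under this name.
            versions[name] = "not installed"
    return versions


def format_report(packages: Iterable[str]) -> str:
    """Build a plain-text environment report suitable for pasting into an issue."""
    lines = [
        f"Python: {platform.python_version()}",
        f"OS: {platform.platform()}",
    ]
    for name, version in collect_versions(packages).items():
        lines.append(f"{name}: {version}")
    return "\n".join(lines)


# Example usage (distribution names are assumptions, adjust as needed):
# print(format_report(["torch", "torch-tensorrt", "torchvision", "tensorrt"]))
```

A package that is missing, or registered under a different distribution name, is reported as "not installed" rather than raising, so the report can be produced even in a partially configured environment.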

Additional context

Please find the error message below:

RuntimeError                              Traceback (most recent call last)
Cell In[3], line 1
----> 1 trt_module = torch_tensorrt.compile(model,
      2     inputs = [torch_tensorrt.Input((1, 3, 720, 1280))], # input shape   
      3     enabled_precisions = {torch.half} # Run with FP16
      4 )
      5 # save the TensorRT embedded Torchscript
      6 torch.jit.save(trt_module, "trt_torchscript_module.ts")

File ~/miniconda3/envs/tensorrt/lib/python3.10/site-packages/torch_tensorrt/_compile.py:125, in compile(module, ir, inputs, enabled_precisions, **kwargs)
    120         logging.log(
    121             logging.Level.Info,
    122             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
    123         )
    124         ts_mod = torch.jit.script(module)
--> 125     return torch_tensorrt.ts.compile(
    126         ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
    127     )
    128 elif target_ir == _IRType.fx:
    129     if (
    130         torch.float16 in enabled_precisions
    131         or torch_tensorrt.dtype.half in enabled_precisions
    132     ):
...

RuntimeError: 
temporary: the only valid use of a module is looking up an attribute but found  = prim::SetAttr[name="_has_warned"](%self, %self.backbone.body.1.use_res_conne
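The traceback shows the exception surfacing from `torch_tensorrt.ts.compile` (line 125 of `_compile.py`), after the `torch.jit.script` call on line 124. One way to confirm which stage fails is to run model construction, scripting, and compilation as separate steps. The sketch below is a diagnostic aid under stated assumptions, not a fix: `first_failing_stage` and `build_stages` are hypothetical helpers, and the compile arguments are copied from the reproduction above. The heavy imports are deferred so the helper itself can be loaded without the GPU stack.

```python
from typing import Any, Callable, List, Optional, Tuple

Stage = Tuple[str, Callable[[Any], Any]]


def first_failing_stage(
    stages: List[Stage],
) -> Tuple[Optional[str], Optional[BaseException]]:
    """Run stages in order, feeding each result into the next stage.

    Returns (name, exception) for the first stage that raises,
    or (None, None) if every stage succeeds.
    """
    value = None
    for name, fn in stages:
        try:
            value = fn(value)
        except Exception as exc:
            return name, exc
    return None, None


def build_stages() -> List[Stage]:
    """Stages mirroring the reproduction: build, script, then compile."""
    # Imported here so this module loads even without these packages.
    import torch
    import torch_tensorrt
    import torchvision

    detection = torchvision.models.detection
    return [
        ("build", lambda _: detection.fasterrcnn_mobilenet_v3_large_320_fpn().eval()),
        ("script", lambda model: torch.jit.script(model)),
        (
            "compile",
            lambda ts_mod: torch_tensorrt.compile(
                ts_mod,
                inputs=[torch_tensorrt.Input((1, 3, 720, 1280))],
                enabled_precisions={torch.half},
            ),
        ),
    ]


# Example usage on a machine with the full stack installed:
# name, exc = first_failing_stage(build_stages())
# print(name, repr(exc))
```

If the "script" stage passes and "compile" fails, that would be consistent with the traceback above and would point at the Torch-TensorRT stage rather than at TorchScript scripting of the torchvision model.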
