❓ [Question] How do you properly deploy a quantized model with tensorrt #3267

Open
Urania880519 opened this issue Oct 29, 2024 · 3 comments
Labels: question (Further information is requested)


@Urania880519

Urania880519 commented Oct 29, 2024

❓ Question

I have a PTQ model and a QAT model, both trained with the official PyTorch API following the quantization tutorial, and I want to deploy them on TensorRT for inference. The model is MetaFormer-like and uses convolution layers as the token mixer. One part of the quantized model looks like this:
[image: part of the quantized model graph]

What you have already tried

I have tried different ways to make things work:

  1. The torch2trt package: dynamic input is a major problem. The dataset consists of inputs of different shapes (B, C, H, W), where H and W are not necessarily the same. There is a torch2trt-dynamic package, but I think its plugins are buggy. The code basically looks like this:
    model_trt = torch2trt(
        model_fp32,
        [torch.randn(1, 11, 64, 64).to('cuda')],
        max_batch_size=batch_size,
        fp16_mode=False,
        int8_mode=True,
        calibrator=trainLoader,
        input_shapes=[(None, 11, None, None)],
    )
  2. torch.compile() with backend="tensorrt": compiling the PTQ model raised RuntimeError: quantized::conv2d (ONEDNN): data type of input should be QUint8. Compiling the QAT model logged:
    W1029 14:21:17.640402 139903289382080 torch/_dynamo/utils.py:1195] [2/0] Unsupported: quantized nyi in meta tensors with fake tensor propagation.
    Here's the code I used:
    trt_gm = torch.compile(model, dynamic=True, backend="tensorrt")
  3. Converting the torch model to ONNX, then converting the ONNX model into a TensorRT engine. There are several problems in this case:
  • The ONNX model runs unusually slowly with ONNX Runtime, and the computed loss is extremely high. Here's an example:
    [image: example of the ONNX Runtime results]

  • I visualized the quantized ONNX model with Netron, because converting it to a TensorRT engine always raises this error:
    [image: TensorRT conversion error]
    This is the problematic part of the graph:
    [image: problematic part of the ONNX graph]
    The rightmost DequantizeLinear node is causing the problem. I checked its input x and found that it is an int32 constant array, while x_scale is a float32 constant array. The output of this node turns out to be the bias passed into the Conv layer.
    Something must be wrong in how the conversion behaves. When quantizing with the PyTorch API, only activations and weights were observed by the defined observers, so I expected only the leftmost and middle DequantizeLinear nodes, with the bias stored in fp32 and passed directly into the Conv layer. Running onnx-simplifier does not remove the node. Because the quantized torch model does not convert cleanly to ONNX, I can't go on to build a TensorRT engine from it. I've considered quantizing with the ONNX API instead, but the performance drop from the unquantized original torch model to the ONNX model is quite concerning.
    The conversion code looks like this:
    torch.onnx.export(
        quantized_model,
        dummy_input,
        args.onnx_export_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=13,
        export_params=True,
        keep_initializers_as_inputs=False,
        dynamic_axes={
            "input": {0: "batch_size", 2: "h", 3: "w"},
            "output": {0: "batch_size", 2: "h", 3: "w"},
        },
    )
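For context on the int32 array: PyTorch's quantized conv kernels take the bias as float and quantize it to int32 internally, at scale input_scale × weight_scale with zero point 0, and the ONNX exporter appears to make that step explicit, which would produce an int32 constant feeding a DequantizeLinear node. The stdlib-only sketch below illustrates per-tensor affine quantization of a quint8 activation, that int32 bias convention, and the ONNX DequantizeLinear formula y = (x - x_zero_point) * x_scale. The scales and values are made up, and the exporter's exact rounding may differ from Python's round().

```python
"""Sketch of per-tensor affine quantization as it appears in a QDQ ONNX graph.

Assumptions: activations are quint8, weights are symmetric qint8 (zero point 0),
and the conv bias is int32 at scale input_scale * weight_scale, zero point 0.
"""


def quantize(values, scale, zero_point, qmin, qmax):
    """q = clamp(round(x / scale) + zero_point, qmin, qmax); round() is half-to-even."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


def dequantize_linear(q_values, scale, zero_point=0):
    """ONNX DequantizeLinear semantics: y = (x - x_zero_point) * x_scale."""
    return [(q - zero_point) * scale for q in q_values]


def quantize_bias_int32(bias, input_scale, weight_scale):
    """Quantize a float bias to int32 at scale input_scale * weight_scale, zero point 0."""
    bias_scale = input_scale * weight_scale
    q_bias = quantize(bias, bias_scale, 0, -(2**31), 2**31 - 1)
    return q_bias, bias_scale


if __name__ == "__main__":
    # Hypothetical scales; real values come from the observers / calibration.
    input_scale, input_zp = 0.1, 128  # quint8 activation
    weight_scale = 0.05               # symmetric qint8 weight

    x_q = quantize([0.0, 1.0, -1.0, 100.0], input_scale, input_zp, 0, 255)
    print("quint8 activation:", x_q)  # 100.0 saturates at 255

    q_bias, bias_scale = quantize_bias_int32([0.5, -0.25], input_scale, weight_scale)
    print("int32 bias:", q_bias)
    print("dequantized bias:", dequantize_linear(q_bias, bias_scale))
```

Running the script prints the quint8 activation [128, 138, 118, 255] and the int32 bias [100, -50]; dequantizing the bias recovers approximately [0.5, -0.25].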

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version: 2.3.1
  • CPU Architecture: x86_64
  • OS: Ubuntu 20.04.4 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): conda
  • Are you using local sources or building from archives: No
  • Python version: 3.9.19
  • CUDA version: 12.1
  • GPU models and configuration:
  • Torch_TensorRT: 2.3.0
  • torch2trt: 0.5.0
  • onnx: 1.16.1

Additional context

Personally, I think torch.compile() is the most promising route to converting the quantized model successfully, since it shows no performance drop. Does anyone have relevant experience handling quantized models?

@Urania880519 Urania880519 added the question Further information is requested label Oct 29, 2024
@narendasan
Collaborator

Did you follow this tutorial? https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_ptq.html

@Urania880519
Author

Urania880519 commented Oct 30, 2024

@narendasan
I've followed both the tutorial you provided and this one: https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html#dynamic-shapes
However, the following error appears after calibration finishes (calibration itself seemed successful, and the loss was quite low):
[images: error raised after calibration]
This is the code I used:

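  import torch
  import torch_tensorrt as torchtrt
  import modelopt.torch.quantization as mtq
  from modelopt.torch.quantization.utils import export_torch_mode
  from torch.export._trace import _export

  # model, channels and calibrate_loop are defined elsewhere (not shown)
  quant_cfg = mtq.INT8_DEFAULT_CFG
  # Insert INT8 quantizers and calibrate them by running calibrate_loop
  mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
  with torch.no_grad():
      with export_torch_mode():
          input_tensor = torch.randn((1, channels, 35, 35), dtype=torch.float32).to('cuda')
          # Height and width are dynamic within [25, 64]
          height_dim = torch.export.Dim("height_dim", min=25, max=64)
          width_dim = torch.export.Dim("width_dim", min=25, max=64)
          dynamic_shapes = ({2: height_dim, 3: width_dim},)
          exp_program = _export(model, (input_tensor,), dynamic_shapes=dynamic_shapes)
          trt_Qmodel = torchtrt.dynamo.compile(
              exp_program,
              inputs=[input_tensor],
              enabled_precisions={torch.int8},
              min_block_size=1,
              debug=False,
              assume_dynamic_shape_support=True,
          )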
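The Dim bounds above (min=25, max=64) have to cover every H and W in the dataset, and a TensorRT optimization profile is usually described as min/opt/max shapes, for example via torch_tensorrt.Input(min_shape=..., opt_shape=..., max_shape=...). A hypothetical stdlib helper for deriving those bounds from the observed sample shapes might look like this; it assumes (C, H, W) samples with a fixed channel count and picks the most frequent (H, W) as the opt shape, which is one reasonable choice rather than a requirement.

```python
"""Hypothetical helper: derive dynamic-shape bounds from a dataset's sample shapes."""
from collections import Counter


def derive_profile(sample_shapes, batch_size=1):
    """Return (min_shape, opt_shape, max_shape) as (B, C, H, W) tuples."""
    if not sample_shapes:
        raise ValueError("need at least one sample shape")
    channels = {c for c, _, _ in sample_shapes}
    if len(channels) != 1:
        raise ValueError(f"expected a single channel count, got {sorted(channels)}")
    (c,) = channels
    heights = [h for _, h, _ in sample_shapes]
    widths = [w for _, _, w in sample_shapes]
    # opt: the most frequent (H, W) pair, so the engine is tuned for the common case
    (opt_h, opt_w), _ = Counter((h, w) for _, h, w in sample_shapes).most_common(1)[0]
    min_shape = (batch_size, c, min(heights), min(widths))
    opt_shape = (batch_size, c, opt_h, opt_w)
    max_shape = (batch_size, c, max(heights), max(widths))
    return min_shape, opt_shape, max_shape


def out_of_range(sample_shapes, min_hw, max_hw):
    """Return the shapes whose H or W falls outside [min_hw, max_hw]."""
    return [
        s for s in sample_shapes
        if not (min_hw <= s[1] <= max_hw and min_hw <= s[2] <= max_hw)
    ]


if __name__ == "__main__":
    # Made-up sample shapes (C, H, W); 11 channels as in the torch2trt attempt
    shapes = [(11, 35, 35), (11, 35, 40), (11, 25, 64), (11, 35, 35), (11, 70, 30)]
    print(derive_profile(shapes))
    print(out_of_range(shapes, 25, 64))
```

With these sample shapes, the helper returns ((1, 11, 25, 30), (1, 11, 35, 35), (1, 11, 70, 64)), and the 70-pixel-high sample falls outside Dim(min=25, max=64), so it would need wider bounds.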

@narendasan
Collaborator

@lanluo-nvidia or @peri044, can you provide additional guidance here?

@lanluo-nvidia lanluo-nvidia self-assigned this Nov 2, 2024

3 participants