
🐛 [Bug] Error when compiling aten model with intermediate int64 tensors #1864

Closed
@gs-olive

Bug Description

When compiling the T5-Base model via the aten path, the following error is encountered:

  File "~/TensorRT/py/torch_tensorrt/fx/fx2trt.py", line 303, in placeholder
    name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
  File "~/TensorRT/py/torch_tensorrt/fx/utils.py", line 45, in torch_dtype_to_trt
    raise TypeError("%s is not supported by tensorrt" % dtype)
TypeError: torch.int64 is not supported by tensorrt

Since none of the input tensors is of type int64, it is presumed that some intermediate tensor encountered during partitioning carries an int64 value, a type generally associated with indices in PyTorch.
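
For reference, the failure can be reproduced in isolation by calling the dtype conversion helper named in the traceback directly. This is a minimal sketch assuming only the module path and function name shown above; the exact return values are illustrative:

  import torch
  from torch_tensorrt.fx.utils import torch_dtype_to_trt

  # int32 has a TensorRT dtype mapping, but int64 does not and raises the TypeError above
  print(torch_dtype_to_trt(torch.int32))  # -> trt.int32
  print(torch_dtype_to_trt(torch.int64))  # -> TypeError: torch.int64 is not supported by tensorrt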

To Reproduce

Steps to reproduce the behavior:

  1. Initialize model: T5Model.from_pretrained("t5-base").eval().cuda()
  2. Initialize three int32 input tensors ("input_ids", "attention_mask", "decoder_input_ids"), for example: torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
  3. (Optional) Use the transformers tools to trace the model via: transformers.utils.fx.symbolic_trace(model, input_names=["input_ids", "attention_mask", "decoder_input_ids"])
  4. Compile the model using the FX/aten path (a reproduction sketch follows this list)
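
Putting the steps together, the following is a minimal reproduction sketch. The torch_tensorrt.fx.compile call and its is_aten flag are assumptions based on the FX frontend at the referenced commit; the exact compile invocation may differ:

  import torch
  import torch_tensorrt.fx
  import transformers.utils.fx
  from transformers import T5Model

  # Step 1: initialize the model
  model = T5Model.from_pretrained("t5-base").eval().cuda()

  # Step 2: three int32 input tensors of shape (1, 14)
  input_ids = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
  attention_mask = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
  decoder_input_ids = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")

  # Step 3 (optional): trace the model with the transformers FX tracer
  traced = transformers.utils.fx.symbolic_trace(
      model, input_names=["input_ids", "attention_mask", "decoder_input_ids"]
  )

  # Step 4: compile via the FX/aten path -- fails with
  # "TypeError: torch.int64 is not supported by tensorrt"
  trt_model = torch_tensorrt.fx.compile(
      traced,
      [input_ids, attention_mask, decoder_input_ids],
      is_aten=True,  # assumption: flag selecting the aten tracing path
  )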

Expected behavior

The model should compile successfully via the aten path.

Environment

  • Transformers: 4.26.1
  • Torch-TensorRT Version: b3f433a (commit)
  • PyTorch Version: 2.1.0.dev20230419+cu117
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.7
