🐛 [Bug] _Expected to find type str for value why_not_sparsity_fast_path.76 but get nothing._ when trying to partially compile TransformerEncoder #1756

Closed
@narendasan

Bug Description

Partially compiling a model that contains nn.TransformerEncoder with Torch-TensorRT fails with the following error:

Expected to find type str for value why_not_sparsity_fast_path.76 but get nothing.

To Reproduce

Steps to reproduce the behavior:

import torch
import torch.nn as nn

import torch_tensorrt

class TransformerModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(TransformerModel, self).__init__()

        # define embedding layer
        self.embedding = nn.Embedding(input_dim, hidden_dim)

        # define transformer encoder
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads),
            num_layers=num_layers
        )

        # define output layer
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # apply embedding layer
        x = self.embedding(x)

        # apply transformer encoder
        x = self.transformer_encoder(x)

        # apply output layer
        x = self.fc(x)

        return x


model = TransformerModel(input_dim=100, hidden_dim=128, num_layers=2, num_heads=4)

input_data = torch.randint(low=0, high=100, size=(32,10)) # shape (32, 10): read as (sequence, batch) because batch_first defaults to False

input_data = input_data.to("cuda").to(torch.int)
model.to("cuda")
output = model(input_data)
model.eval()

inputs = [
    torch_tensorrt.Input(
        min_shape=[32,10],
        opt_shape=[32,10],
        max_shape=[32,10],
        dtype=torch.int,
    )]

enabled_precisions = {torch.float, torch.half}  # Run with fp16


with torch_tensorrt.logging.graphs():
    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions, require_full_compilation=True
    )

result = trt_ts_module(input_data)

with open("../saved_models/trt_ts_module.ts", "wb") as f:
    torch.jit.save(trt_ts_module, f)

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): main
  • PyTorch Version (e.g. 1.0): 2.0
  • CPU Architecture: x64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): source
  • Build command you used (if compiling from source): pip install -e py
  • Are you using local sources or building from archives: archives
  • Python version: 3.9
  • CUDA version: 11.8
  • GPU models and configuration: 3080Ti
  • Any other relevant information:

Additional context
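
A possible way to narrow this down, as a sketch only: the error names a value of type str, so listing the str-typed values in a TorchScript graph dump may help locate where it originates. The helper name find_typed_values and the SAMPLE_GRAPH text are assumptions made for illustration; the sample is hand-written in TorchScript IR syntax and is not output from this model or from Torch-TensorRT.

```python
import re
from typing import List


def find_typed_values(graph_text: str, type_name: str = "str") -> List[str]:
    """Return names of values declared with exactly `type_name` in a graph dump.

    Matches declarations of the form `%name : type`. Container and optional
    types such as `str[]` or `str?` are skipped by the negative lookahead.
    """
    pattern = re.compile(
        r"%([\w.]+)\s*:\s*" + re.escape(type_name) + r"(?![\w\[\?])"
    )
    names: List[str] = []
    for name in pattern.findall(graph_text):
        if name not in names:  # keep first occurrence, preserve order
            names.append(name)
    return names


# Hand-written illustrative graph text (NOT real output from the issue).
SAMPLE_GRAPH = """
graph(%self : __torch__.TransformerModel, %x.1 : Tensor):
  %why_not_sparsity_fast_path.1 : str = prim::Constant[value=""]()
  %4 : int = prim::Constant[value=1]()
  %names.1 : str[] = prim::ListConstruct(%why_not_sparsity_fast_path.1)
  return (%x.1)
"""

if __name__ == "__main__":
    for value_name in find_typed_values(SAMPLE_GRAPH):
        print(value_name)
```

With the model from the reproduction script, the same helper could be applied to str(torch.jit.script(model).inlined_graph), or to the graph text logged inside the torch_tensorrt.logging.graphs() block.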
