
🐛 [Bug] Softmax in nn.MultiheadAttention layer not fused with torch.compile backend #2267

Closed
@austinapatel

Bug Description

When using the TorchTRT torch.compile backend, the softmax in nn.MultiheadAttention is not fused into a TRT engine, leading to a performance regression compared to PyTorch eager mode. In particular, with TorchTRT the softmax is executed as 5 separate ops (aten::amax, aten::sub, aten::exp, aten::sum and aten::div), none of which end up inside a TRT engine, while in PyTorch eager mode the entire softmax executes as a single op (aten::_softmax). This leads to a performance regression when using TorchTRT for the MHA layer.

Eager mode (the softmax runs as a single op, highlighted in the screenshot softmax_eager): ~56 us

TRT (the softmax runs as 5 ops outside the engine, highlighted in the screenshot softmax_trt): ~180 us
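
For reference, the five ops line up with the standard decomposition of aten::_softmax (subtract the row max for numerical stability, exponentiate, sum, divide). Below is a minimal standalone sketch of that decomposition using make_fx and get_decompositions from stock PyTorch; it is illustrative only and is not the actual TorchTRT lowering path.

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Decompose only aten._softmax; illustrative only, not the exact decomposition
# table the TorchTRT torch.compile backend applies internally.
decomps = get_decompositions([torch.ops.aten._softmax.default])

def softmax_fn(x):
    return torch.softmax(x, dim=-1)

gm = make_fx(softmax_fn, decomposition_table=decomps)(torch.rand(8, 8))
# The printed graph should contain aten.amax, aten.sub, aten.exp, aten.sum and
# aten.div (possibly alongside a few layout/dtype ops).
print(gm.graph)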

To Reproduce

Steps to reproduce the behavior:

  1. Profile with Nsight Systems (or just run the script directly): nsys profile --trace cuda,nvtx --sample cpu --force-overwrite true --output profiling_results/tmp --gpu-metrics-device=all --gpu-metrics-frequency=20000 python torchtrt_softmax_issue.py

torchtrt_softmax_issue.py

import torch
from torch import nn
import torch_tensorrt

class MHANetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=1000, num_heads=1)

    def forward(self, x):
        x, _ = self.mha(x, x, x)
        return x


"""PROFILING UTILITIES"""
def timed(fn, msg):
    # Time fn() with CUDA events, bracketed by NVTX markers
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.nvtx.mark(f'timed start: {msg}')
    start.record()
    result = fn()
    end.record()
    torch.cuda.nvtx.mark(f'timed end: {msg}')
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)


def profile(method, nvtx_msg=None):
    # Emit cudaProfilerStart/Stop plus optional NVTX markers around the call
    print('Profiling with torch NVTX markers enabled')
    torch.cuda.cudart().cudaProfilerStart()
    with torch.autograd.profiler.emit_nvtx(record_shapes=True):
        if nvtx_msg:
            torch.cuda.nvtx.mark(f'profile start: {nvtx_msg}')
        result = method()
        if nvtx_msg:
            torch.cuda.nvtx.mark(f'profile end: {nvtx_msg}')
    torch.cuda.cudart().cudaProfilerStop()

    return result


def run_two_models(model1, model2, inp):
    # Run each model once and return the CUDA-event elapsed times (ms)
    torch.cuda.synchronize()

    output1, time1 = timed(lambda: model1(inp), f'profiling model1')
    output2, time2 = timed(lambda: model2(inp), f'profiling model2')

    return time1, time2

"""MAIN"""
if __name__ == '__main__':
    inp = torch.rand((1000, 1000), device='cuda:0')

    # Eager
    eager_model = MHANetwork()
    eager_model.to('cuda:0').eval()
    print(f'Calling eager model')
    eager_model(inp)

    # TorchTRT
    trt_model = torch_tensorrt.compile(eager_model, ir="torch_compile", inputs=inp, debug=True)
    print(f'Calling trt_model')
    trt_model(inp)

    # Profiling
    time_eager, time_trt = profile(lambda: run_two_models(eager_model, trt_model, inp))
    print(f'Eager time: {time_eager}ms\nTime TRT: {time_trt}ms')

Output:

<TRT debug logs...>
Calling eager model
Calling trt_model
Profiling with torch NVTX markers enabled
Eager time: 1.3304959535598755ms
Time TRT: 3.206144094467163ms

Note the slower time with TRT.

Note: if you are using a TorchTRT debug build you might get an error about profiling already being enabled. In this case you can either:

  1. Replace time_eager, time_trt = profile(lambda: run_two_models(eager_model, trt_model, inp)) with time_eager, time_trt = run_two_models(eager_model, trt_model, inp), or
  2. Change cmd.append("--compilation_mode=dbg") to cmd.append("--compilation_mode=opt") and re-build TorchTRT.

Expected behavior

I expected that the softmax would either 1) be fused into a TRT engine or 2) execute as the single aten::_softmax op rather than the five constituent ops that form a softmax.

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): commit 65c6494ce3107d33b27fdb1630ac7982f8649382 built from source
  • PyTorch Version (e.g. 1.0): 2.1.0a0+4136153 (commit 7682252cac9ed31055d2ae950cf6942dd311da73)
  • CPU Architecture:
  • OS (e.g., Linux): Ubuntu 22.04.2 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): source
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.6
  • CUDA version: 12.1
  • GPU models and configuration: NVIDIA A100 80GB PCIe
  • Any other relevant information:

Additional context

From debug logs:

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.unsqueeze.default + Operator Count: 3
- torch.ops.aten.permute.default + Operator Count: 9
- torch.ops.aten.view.default + Operator Count: 11
- torch.ops.aten.mul.Tensor + Operator Count: 8
- torch.ops.aten.mm.default + Operator Count: 4
- torch.ops.aten.add.Tensor + Operator Count: 4
- torch.ops.aten.div.Tensor + Operator Count: 2
- torch.ops.aten.sub.Tensor + Operator Count: 1
- torch.ops.aten.exp.default + Operator Count: 1
- torch.ops.aten.squeeze.dim + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.split.Tensor + Operator Count: 2
- _operator.getitem + Operator Count: 6
- torch.ops.aten.bmm.default + Operator Count: 2
- torch.ops.aten.amax.default + Operator Count: 1
- torch.ops.aten.sum.dim_IntList + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.compile:Detected support for 44 operators out of 56 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:Eliminating acc subgraph because it's smaller than the threshold: 2 < 5
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:Eliminating acc subgraph because it's smaller than the threshold: 1 < 5
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 2
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.unsqueeze.default + Operator Count: 3
- torch.ops.aten.permute.default + Operator Count: 9
- torch.ops.aten.view.default + Operator Count: 11
- torch.ops.aten.mul.Tensor + Operator Count: 8
- torch.ops.aten.mm.default + Operator Count: 4
- torch.ops.aten.add.Tensor + Operator Count: 4
- torch.ops.aten.div.Tensor + Operator Count: 2
- torch.ops.aten.sub.Tensor + Operator Count: 1
- torch.ops.aten.exp.default + Operator Count: 1
- torch.ops.aten.squeeze.dim + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.split.Tensor + Operator Count: 2
- _operator.getitem + Operator Count: 6
- torch.ops.aten.bmm.default + Operator Count: 2
- torch.ops.aten.amax.default + Operator Count: 1
- torch.ops.aten.sum.dim_IntList + Operator Count: 1

These logs suggest torch.ops.aten.amax.default is not supported, which could explain why the decomposed softmax ops aren't getting fused into an engine.
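
As a sanity check, a minimal sketch along the lines of the repro script above can isolate the amax question: compile a module whose graph contains only aten.amax and read the same partitioner debug output. The module name here is hypothetical, and min_block_size=1 is an assumption about the dynamo partitioner settings, intended to keep the single-op subgraph from being eliminated by the size threshold seen in the logs.

import torch
from torch import nn
import torch_tensorrt

class AmaxOnly(nn.Module):
    # Graph containing only aten.amax, to see whether it lands in a TRT engine
    def forward(self, x):
        return torch.amax(x, dim=-1, keepdim=True)

mod = AmaxOnly().to('cuda:0').eval()
x = torch.rand((1000, 1000), device='cuda:0')

# min_block_size=1 is assumed to be forwarded to the partitioner so the
# one-op subgraph is not dropped by the "smaller than the threshold" check
trt_mod = torch_tensorrt.compile(mod, ir="torch_compile", inputs=x, debug=True, min_block_size=1)
trt_mod(x)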
