Bug Description
When using the Torch-TensorRT torch.compile backend, the softmax in nn.MultiheadAttention
is not fused into an engine, leading to a performance regression compared to PyTorch eager mode. Specifically, with Torch-TRT the softmax executes as 5 separate ops (`aten::amax`, `aten::sub`, `aten::exp`, `aten::sum`, and `aten::div`) with no TRT engine created for these ops, while in PyTorch eager mode the entire softmax executes as a single op (`aten::_softmax`). This causes a performance regression when using Torch-TRT for the MHA layer.
- Eager mode (single op for softmax, highlighted in the screenshot): ~56us
- TRT (softmax takes 5 ops outside of the engine, highlighted in the screenshot): ~180us
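For context, the five ops match PyTorch's reference decomposition of `aten._softmax` into `amax`/`sub`/`exp`/`sum`/`div`. A minimal sketch (independent of Torch-TRT, using `make_fx` with the stock decomposition table) that reproduces the decomposed graph:

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Sketch only: trace softmax with the stock aten._softmax decomposition enabled.
# The resulting graph contains amax / sub / exp / sum / div nodes instead of a
# single _softmax node, matching the five ops seen in the trace above.
decomps = get_decompositions([torch.ops.aten._softmax])
gm = make_fx(lambda x: torch.softmax(x, dim=-1),
             decomposition_table=decomps)(torch.rand(8, 8))
print(gm.graph)
```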
To Reproduce
Steps to reproduce the behavior:
- Profile with NSight Systems (or just run the script directly):

```
nsys profile --trace cuda,nvtx --sample cpu --force-overwrite true --output profiling_results/tmp --gpu-metrics-device=all --gpu-metrics-frequency=20000 python torchtrt_softmax_issue.py
```

torchtrt_softmax_issue.py:
```python
import torch
from torch import nn
import torch_tensorrt


class MHANetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=1000, num_heads=1)

    def forward(self, x):
        x, _ = self.mha(x, x, x)
        return x


"""PROFILING UTILITIES"""
def timed(fn, msg):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.nvtx.mark(f'timed start: {msg}')
    start.record()
    result = fn()
    end.record()
    torch.cuda.nvtx.mark(f'timed end: {msg}')
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)


def profile(method, nvtx_msg=None):
    # profiling
    print('Profiling with torch NVTX markers enabled')
    torch.cuda.cudart().cudaProfilerStart()
    with torch.autograd.profiler.emit_nvtx(record_shapes=True):
        if nvtx_msg:
            torch.cuda.nvtx.mark(f'profile start: {nvtx_msg}')
        result = method()
        if nvtx_msg:
            torch.cuda.nvtx.mark(f'profile end: {nvtx_msg}')
    torch.cuda.cudart().cudaProfilerStop()
    return result


def run_two_models(model1, model2, inp):
    torch.cuda.synchronize()
    output1, time1 = timed(lambda: model1(inp), f'profiling model1')
    output2, time2 = timed(lambda: model2(inp), f'profiling model2')
    return time1, time2


"""MAIN"""
if __name__ == '__main__':
    inp = torch.rand((1000, 1000), device='cuda:0')
    # Eager
    eager_model = MHANetwork()
    eager_model.to('cuda:0').eval()
    print(f'Calling eager model')
    eager_model(inp)
    # TorchTRT
    trt_model = torch_tensorrt.compile(eager_model, ir="torch_compile", inputs=inp, debug=True)
    print(f'Calling trt_model')
    trt_model(inp)
    # Profiling
    time_eager, time_trt = profile(lambda: run_two_models(eager_model, trt_model, inp))
    print(f'Eager time: {time_eager}ms\nTime TRT: {time_trt}ms')
```
Output:

```
<TRT debug logs...>
Calling eager model
Calling trt_model
Profiling with torch NVTX markers enabled
Eager time: 1.3304959535598755ms
Time TRT: 3.206144094467163ms
```
Note the slower time with TRT.
Note: if you are using a TorchTRT debug build, you might get an error about profiling already being enabled. In this case you can either:
- Replace `time_eager, time_trt = profile(lambda: run_two_models(eager_model, trt_model, inp))` with `time_eager, time_trt = run_two_models(eager_model, trt_model, inp)`, or
- Change `dbg` to `opt` (Line 141 in b774440) and re-build TorchTRT.
Expected behavior
I expected that the softmax would either 1) be fused into a TRT engine, or 2) execute as the single `aten::_softmax` op rather than as the five constituent ops that form a softmax.
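For reference on option 2, tracing the same softmax without the `_softmax` decomposition keeps it as one node; a minimal sketch (same assumptions as the earlier snippet):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Sketch: with no decomposition table, the traced graph keeps a single
# aten._softmax node, which is what eager mode (and option 2 above) runs.
gm = make_fx(lambda x: torch.softmax(x, dim=-1))(torch.rand(8, 8))
print(gm.graph)  # one _softmax call instead of amax/sub/exp/sum/div
```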
Environment
- Torch-TensorRT Version (e.g. 1.0.0): built from source at commit `65c6494ce3107d33b27fdb1630ac7982f8649382`
- PyTorch Version (e.g. 1.0): `2.1.0a0+4136153` (commit `7682252cac9ed31055d2ae950cf6942dd311da73`)
- CPU Architecture:
- OS (e.g., Linux): Ubuntu 22.04.2 LTS
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): source
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.10.6
- CUDA version: 12.1
- GPU models and configuration: NVIDIA A100 80GB PCIe
- Any other relevant information:
Additional context
From the debug logs:

```
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.unsqueeze.default + Operator Count: 3
- torch.ops.aten.permute.default + Operator Count: 9
- torch.ops.aten.view.default + Operator Count: 11
- torch.ops.aten.mul.Tensor + Operator Count: 8
- torch.ops.aten.mm.default + Operator Count: 4
- torch.ops.aten.add.Tensor + Operator Count: 4
- torch.ops.aten.div.Tensor + Operator Count: 2
- torch.ops.aten.sub.Tensor + Operator Count: 1
- torch.ops.aten.exp.default + Operator Count: 1
- torch.ops.aten.squeeze.dim + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.split.Tensor + Operator Count: 2
- _operator.getitem + Operator Count: 6
- torch.ops.aten.bmm.default + Operator Count: 2
- torch.ops.aten.amax.default + Operator Count: 1
- torch.ops.aten.sum.dim_IntList + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.compile:Detected support for 44 operators out of 56 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:Eliminating acc subgraph because it's smaller than the threshold: 2 < 5
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:Eliminating acc subgraph because it's smaller than the threshold: 1 < 5
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 2
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.unsqueeze.default + Operator Count: 3
- torch.ops.aten.permute.default + Operator Count: 9
- torch.ops.aten.view.default + Operator Count: 11
- torch.ops.aten.mul.Tensor + Operator Count: 8
- torch.ops.aten.mm.default + Operator Count: 4
- torch.ops.aten.add.Tensor + Operator Count: 4
- torch.ops.aten.div.Tensor + Operator Count: 2
- torch.ops.aten.sub.Tensor + Operator Count: 1
- torch.ops.aten.exp.default + Operator Count: 1
- torch.ops.aten.squeeze.dim + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.split.Tensor + Operator Count: 2
- _operator.getitem + Operator Count: 6
- torch.ops.aten.bmm.default + Operator Count: 2
- torch.ops.aten.amax.default + Operator Count: 1
- torch.ops.aten.sum.dim_IntList + Operator Count: 1
```
These logs suggest `torch.ops.aten.amax.default` (and `torch.ops.aten.sum.dim_IntList`) are not supported, which could explain why the decomposed softmax isn't being fused into an engine.
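Separately, the "2 < 5" / "1 < 5" lines above refer to the partitioner's minimum block size (default 5). A hedged sketch (not a fix for the missing converters) to check whether keeping the small supported subgraphs changes the partitioning, assuming `min_block_size` is forwarded through the `torch_compile` IR path:

```python
# Reuses MHANetwork from the repro script above.
import torch
import torch_tensorrt

inp = torch.rand((1000, 1000), device='cuda:0')
model = MHANetwork().to('cuda:0').eval()

# Lower the partitioner threshold so subgraphs smaller than the default of 5
# (the "2 < 5" / "1 < 5" eliminations in the logs) are not dropped back to eager.
trt_model = torch_tensorrt.compile(model, ir="torch_compile", inputs=inp,
                                   min_block_size=1, debug=True)
trt_model(inp)
```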