Description
Your current environment
Error information (Sorry, I cannot disclose more due to confidentiality reasons)
Collecting environment information...
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (aarch64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: glibc-2.35
Python version: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-118-generic-aarch64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mindietorch==1.0.0+torch2.1.0.abi0
[pip3] numpy==1.26.4
[pip3] pyzmq==26.2.0
[pip3] torch==2.1.0
[pip3] torch-npu==2.1.0.post10.dev20241217
[pip3] torchvision==0.16.0
[pip3] transformers==4.46.1
[pip3] tritonclient==2.49.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.3.dev0+g7193774b.d20250103
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect
PYTORCH_INSTALL_PATH=/usr/local/lib/python3.10/dist-packages/torch
PYTORCH_NPU_INSTALL_PATH=/usr/local/lib/python3.10/dist-packages/torch_npu
Model Input Dumps
No response
🐛 Describe the bug
Description
We encountered the following error while running InternVL:
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
After careful analysis, we found that this is a logic bug that has nothing to do with the device we are using.
The issue occurs in the following code in vllm/vllm/model_executor/models/intern_vit.py:
class InternSdpaAttention(nn.Module):

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, N, self.num_heads, self.head_dim)
        k = k.view(B, N, self.num_heads, self.head_dim)
        v = v.view(B, N, self.num_heads, self.head_dim)

        if self.qk_normalization:
            B_, N_, H_, D_ = q.shape
            q = self.q_norm.forward_native(q.flatten(-2, -1)).view(B_, N_, H_, D_)
            k = self.k_norm.forward_native(k.flatten(-2, -1)).view(B_, N_, H_, D_)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
        # The next line raises the error above: after the transpose the tensor
        # can be non-contiguous in a way that .view() does not accept.
        x = x.transpose(1, 2).view(B, N, -1)
Analysis
At the beginning of this file, there is a check to determine whether xformers
is available:
try:
    from xformers import ops as xops
    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False
The execution branch taken depends on the value of USE_XFORMERS_OPS. The error in question occurs only when USE_XFORMERS_OPS = False, which makes it hard to reproduce in typical setups (e.g. on NVIDIA GPUs, where xformers is usually installed).
Why calling .transpose() and .view() consecutively is not recommended: in the affected code, the output of the attention operation is a 4D tensor, which we will call A. If A is a strided tensor, then according to the PyTorch documentation for transpose, the .transpose operation only creates a view of A by permuting its strides. Meanwhile, the documentation for .view requires the new view to be compatible with the original size and stride: each new dimension must either be a subspace of an original dimension, or span only original dimensions that satisfy the contiguity-like condition stride[i] = stride[i+1] * size[i+1]. Since .transpose changes the strides, the tensor may no longer satisfy these conditions, and the subsequent .view raises the error above. The .view documentation also suggests calling .contiguous() before .view to ensure the tensor is contiguous in memory.
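As a concrete illustration of that condition (using made-up dimensions that stand in for batch, heads, sequence length and head size, not values taken from InternVL), the strides after the transpose no longer satisfy stride[i] = stride[i+1] * size[i+1] for the two dimensions that .view(B, N, -1) tries to merge:

import torch

# Made-up stand-ins for (batch, heads, seq_len, head_dim), the shape produced
# by scaled_dot_product_attention.
A = torch.randn(2, 8, 16, 64)      # contiguous, strides (8192, 1024, 64, 1)
At = A.transpose(1, 2)             # shape (2, 16, 8, 64), strides (8192, 64, 1024, 1)

# .view(B, N, -1) needs to merge the last two dimensions (8 and 64).
# The documented condition for merging dims 2 and 3 is stride[2] == stride[3] * size[3].
print(At.stride(2) == At.stride(3) * At.size(3))   # 1024 == 1 * 64 -> False
# Consequently At.view(2, 16, -1) raises the "view size is not compatible ..." error,
# while At.contiguous().view(2, 16, -1) and At.reshape(2, 16, -1) both succeed.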
Solution
The issue arises because the tensor A is strided and, after .transpose, its strides likely no longer meet the conditions required by .view. To resolve this, we recommend adding a .contiguous() call before the .view operation to ensure the tensor is contiguous in memory.
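A minimal sketch of the proposed fix, assuming the rest of the method stays unchanged (using .reshape(...), as the error message itself suggests, would be functionally equivalent here):

x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
# Make the transposed tensor contiguous before collapsing the head dimensions,
# so that .view() succeeds regardless of the strides produced by the transpose.
x = x.transpose(1, 2).contiguous().view(B, N, -1)
# Equivalent alternative: x = x.transpose(1, 2).reshape(B, N, -1)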
P.S. We also found that .transpose().view() is only valid if:
- A is a sparse tensor instead of a strided tensor.
- The strides after transpose happen to meet the compatibility conditions for .view (see the example below).
- PyTorch may have undocumented methods for handling these situations, such as not requiring .contiguous() when the changes made by .view and .transpose are mirrored.
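As an example of the second case (our own toy code, not taken from vLLM), swapping a size-1 dimension leaves the tensor contiguous, so the subsequent .view succeeds:

import torch

x = torch.randn(2, 1, 4, 5)
y = x.transpose(1, 2)           # shape (2, 4, 1, 5); still contiguous because a size-1 dim was swapped
print(y.is_contiguous())        # True
print(y.view(2, 4, -1).size())  # torch.Size([2, 4, 5]) -- no error in this case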
Related Issues
This issue is similar to the one mentioned in issue #8630, but that report did not identify what prevents the problem from reproducing, namely the value of USE_XFORMERS_OPS. It was addressed in PR #8880; however, since PR #9560 was believed to have solved the problem, PR #8880 was closed. We found that the issue was not actually resolved by PR #9560, so we are submitting this new issue with our analysis and an explanation based on the official PyTorch documentation.
Reproducible Example
Here is a simple code snippet to verify the issue:
import torch

x = torch.randn(2, 3, 4, 5)
y = x.transpose(1, 2)   # a view with permuted strides, no data copy
# y = y.contiguous()    # uncommenting this line makes the .view below succeed
try:
    z = y.view(2, 4, -1)
    print(z.size())
except Exception as e:
    print("Error:", e)
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.