Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch] Fix attention on linalg for dynamic shapes #3714

Merged
merged 3 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 14 additions & 26 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,35 +1607,23 @@ class ConvertAtenScaledDotProductAttentionOp
op.getLoc(), "expected no attention mask when isCausal is true");
}

SmallVector<OpFoldResult> maskSizes;

if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
auto seqLenQ =
rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
auto seqLenK =
rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
maskSizes = {seqLenQ, seqLenK};
for (int i = queryTy.getRank() - 3; i >= 0; --i) {
auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
maskSizes.insert(maskSizes.begin(), batchSize);
}
} else { // Dynamic shape case: <?x?x...x?xf32> for example
for (int i = 0; i < queryTy.getRank() - 2; ++i) {
Value batchSize =
rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
maskSizes.push_back(batchSize);
}
Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
queryTy.getRank() - 2);
Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
keyTy.getRank() - 2);
maskSizes.push_back(seqLenQ);
maskSizes.push_back(seqLenK);
SmallVector<int64_t> maskStatic;
SmallVector<Value> maskDyn;
for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
maskStatic.push_back(queryTy.getDimSize(i));
if (maskStatic.back() == ShapedType::kDynamic)
maskDyn.push_back(
rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
}

maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
if (maskStatic.back() == ShapedType::kDynamic)
maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
keyTy.getRank() - 2));

Type maskType = getElementTypeOrSelf(queryTy);
Value emptyMask =
rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
Value emptyMask = rewriter.create<tensor::EmptyOp>(
op.getLoc(), maskStatic, maskType, maskDyn);

Value zero = rewriter.create<arith::ConstantOp>(
op.getLoc(),
Expand Down
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
# WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
Expand Down Expand Up @@ -822,6 +823,7 @@
"SafeSoftmaxNonNoneDtypeModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
Expand Down Expand Up @@ -3154,6 +3156,7 @@
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScatterAddStaticModule_basic",
"TensorsConcatComplex128FloatModule_basic",
Expand Down Expand Up @@ -4657,6 +4660,7 @@
"ScalarImplicitIntModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterReduceFloatMaxModule",
Expand Down
29 changes: 29 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5370,6 +5370,35 @@ def forward(self, query, key, value):
@register_test_case(
module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
)
def ScaledDotProductAttentionDifferentDynamicCausalModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
module.forward(query, key, value)


class ScaledDotProductAttentionDifferentDynamicCausalModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 3, -1, 16], torch.float32, True),
([2, 3, -1, 16], torch.float32, True),
([2, 3, -1, 20], torch.float32, True),
]
)
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(
query, key, value, is_causal=True
)


@register_test_case(
module_factory=lambda: ScaledDotProductAttentionDifferentDynamicCausalModule()
)
def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
Expand Down
Loading