Skip to content

Commit

Permalink
[linalg] Broadcast batch for mask on sdpa lowering (#3824)
Browse files Browse the repository at this point in the history
Attention often broadcasts a mask across the batch dimension as masking
is usually performed the same across attention heads. Added this
materialization to the mask dimensions optionally.
  • Loading branch information
rsuderman authored Nov 1, 2024
1 parent 5aa323d commit 25738b8
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 25 deletions.
102 changes: 79 additions & 23 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1661,6 +1661,7 @@ class ConvertAtenScaledDotProductAttentionOp
auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType());

auto loc = op.getLoc();
Value dropoutP = op.getDropoutP();
Value isCausal = op.getIsCausal();
Value scale = op.getScale();
Expand All @@ -1671,49 +1672,46 @@ class ConvertAtenScaledDotProductAttentionOp
double dropout;
if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
dropout > 0.0)
return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
return rewriter.notifyMatchFailure(loc, "dropout not supported");

bool causal;
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
if (!isa<Torch::NoneType>(mask.getType())) {
return rewriter.notifyMatchFailure(
op.getLoc(), "expected no attention mask when isCausal is true");
loc, "expected no attention mask when isCausal is true");
}

SmallVector<int64_t> maskStatic;
SmallVector<Value> maskDyn;
for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
maskStatic.push_back(queryTy.getDimSize(i));
if (maskStatic.back() == ShapedType::kDynamic)
maskDyn.push_back(
rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
maskDyn.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
}

maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
if (maskStatic.back() == ShapedType::kDynamic)
maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
keyTy.getRank() - 2));
maskDyn.push_back(
rewriter.create<tensor::DimOp>(loc, key, keyTy.getRank() - 2));

Type maskType = getElementTypeOrSelf(queryTy);
Value emptyMask = rewriter.create<tensor::EmptyOp>(
op.getLoc(), maskStatic, maskType, maskDyn);
Value emptyMask =
rewriter.create<tensor::EmptyOp>(loc, maskStatic, maskType, maskDyn);

Value zero = rewriter.create<arith::ConstantOp>(
op.getLoc(),
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
Value negInf = rewriter.create<arith::ConstantOp>(
op.getLoc(),
loc,
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));

mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
.getResult(0);
mask = rewriter.create<linalg::FillOp>(loc, zero, emptyMask).getResult(0);

int64_t rank = cast<ShapedType>(queryTy).getRank();
AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
auto genericOp = rewriter.create<linalg::GenericOp>(
op.getLoc(), mask.getType(), ValueRange{}, mask,
loc, mask.getType(), ValueRange{}, mask,
SmallVector<AffineMap>{maskMap}, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
Expand All @@ -1727,18 +1725,78 @@ class ConvertAtenScaledDotProductAttentionOp
mask = genericOp.getResult(0);
}

// Broadcast the batch dimensions of the mask:
if (!isa<Torch::NoneType>(mask.getType())) {
auto maskTy = cast<RankedTensorType>(mask.getType());
int64_t rank = maskTy.getRank();
bool needsBroadcast = false;
for (int i = 0, s = rank - 2; i < s; ++i) {
needsBroadcast |= maskTy.getDimSize(i) != keyTy.getDimSize(i);
}

if (needsBroadcast) {
SmallVector<int64_t> maskShape;
SmallVector<Value> maskDynDims;

SmallVector<AffineExpr> maskExprs;
for (int i = 0, s = rank - 2; i < s; ++i) {
maskShape.push_back(keyTy.getDimSize(i));

if (maskTy.getDimSize(i) != keyTy.getDimSize(i)) {
maskExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
maskExprs.push_back(rewriter.getAffineDimExpr(i));
}

if (keyTy.isDynamicDim(i)) {
maskDynDims.push_back(rewriter.create<tensor::DimOp>(loc, key, i));
}
}

maskExprs.push_back(rewriter.getAffineDimExpr(rank - 2));
maskExprs.push_back(rewriter.getAffineDimExpr(rank - 1));
maskShape.push_back(maskTy.getDimSize(rank - 2));
maskShape.push_back(maskTy.getDimSize(rank - 1));
if (maskTy.isDynamicDim(rank - 2))
maskDynDims.push_back(
rewriter.create<tensor::DimOp>(loc, mask, rank - 2));
if (maskTy.isDynamicDim(rank - 1))
maskDynDims.push_back(
rewriter.create<tensor::DimOp>(loc, mask, rank - 1));

SmallVector<AffineMap> affineMaps = {
AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, maskExprs,
op.getContext()),
rewriter.getMultiDimIdentityMap(rank)};
SmallVector<utils::IteratorType> findMaxIteratorTypes(
rank, utils::IteratorType::parallel);

Value emptyMask = rewriter.create<tensor::EmptyOp>(
loc, maskShape, maskTy.getElementType(), maskDynDims);
Value newMask =
rewriter
.create<linalg::GenericOp>(
loc, emptyMask.getType(), mask, ValueRange({emptyMask}),
affineMaps, findMaxIteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
mask = newMask;
}
}

if (!isa<Torch::NoneType>(scale.getType())) {
double scaleFloat;
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
scaleFloat != 1.0)
return rewriter.notifyMatchFailure(op.getLoc(),
"only default scale supported");
return rewriter.notifyMatchFailure(loc, "only default scale supported");
}
bool isGQAEnabled;
if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
isGQAEnabled)
return rewriter.notifyMatchFailure(
op.getLoc(), "grouped query attention not supported");
loc, "grouped query attention not supported");

if (queryTy.getRank() != valueTy.getRank() ||
queryTy.getRank() != keyTy.getRank())
Expand All @@ -1753,7 +1811,6 @@ class ConvertAtenScaledDotProductAttentionOp
reassociation[1].push_back(valueTy.getRank() - 2);
reassociation[2].push_back(valueTy.getRank() - 1);

auto loc = op.getLoc();
auto collapseBatch = [&rewriter, &reassociation,
loc](Value value) -> Value {
auto valueTy = cast<ShapedType>(value.getType());
Expand Down Expand Up @@ -1788,13 +1845,12 @@ class ConvertAtenScaledDotProductAttentionOp
SmallVector<int64_t> valueSizes(
cast<ShapedType>(value.getType()).getShape());
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
SmallVector<Value> outSizesDynamic(
getTensorSizes(rewriter, op.getLoc(), query));
SmallVector<Value> outSizesDynamic(getTensorSizes(rewriter, loc, query));
outSizesDynamic[outSizesDynamic.size() - 1] =
getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1];
getTensorSizes(rewriter, loc, value)[valueSizes.size() - 1];
Type outType = RankedTensorType::get(outSizes, elementType);
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
elementType);
Value output =
createZeroInitTensor(rewriter, loc, outSizesDynamic, elementType);

SmallVector<Value> inputs = SmallVector<Value>{query, key, value};

Expand Down
4 changes: 2 additions & 2 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5501,7 +5501,7 @@ def __init__(self):
([2, 3, 8, 16], torch.float32, True),
([2, 3, 12, 16], torch.float32, True),
([2, 3, 12, 20], torch.float32, True),
([2, 3, 8, 12], torch.float32, True),
([2, 1, 8, 12], torch.float32, True),
]
)
def forward(self, query, key, value, mask):
Expand All @@ -5513,7 +5513,7 @@ def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
mask = torch.randn(2, 1, 8, 12, dtype=torch.float32)
module.forward(query, key, value, mask)


Expand Down

0 comments on commit 25738b8

Please sign in to comment.