Skip to content

Commit 5f26497

Browse files
authored
[mlir][vector] Use DenseI64ArrayAttr in vector.multi_reduction (#102637)
This prevents some unnecessary conversions to/from int64_t and IntegerAttr.
1 parent 2849ebb commit 5f26497

File tree

3 files changed

+13
-15
lines changed

3 files changed

+13
-15
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def Vector_MultiDimReductionOp :
286286
Arguments<(ins Vector_CombiningKindAttr:$kind,
287287
AnyVector:$source,
288288
AnyType:$acc,
289-
I64ArrayAttr:$reduction_dims)>,
289+
DenseI64ArrayAttr:$reduction_dims)>,
290290
Results<(outs AnyType:$dest)> {
291291
let summary = "Multi-dimensional reduction operation";
292292
let description = [{
@@ -325,8 +325,8 @@ def Vector_MultiDimReductionOp :
325325

326326
SmallVector<bool> getReductionMask() {
327327
SmallVector<bool> res(getSourceVectorType().getRank(), false);
328-
for (auto ia : getReductionDims().getAsRange<IntegerAttr>())
329-
res[ia.getInt()] = true;
328+
for (int64_t dim : getReductionDims())
329+
res[dim] = true;
330330
return res;
331331
}
332332
static SmallVector<bool> getReductionMask(

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
445445
for (const auto &en : llvm::enumerate(reductionMask))
446446
if (en.value())
447447
reductionDims.push_back(en.index());
448-
build(builder, result, kind, source, acc,
449-
builder.getI64ArrayAttr(reductionDims));
448+
build(builder, result, kind, source, acc, reductionDims);
450449
}
451450

452451
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
@@ -466,12 +465,14 @@ LogicalResult MultiDimReductionOp::verify() {
466465
SmallVector<bool> scalableDims;
467466
Type inferredReturnType;
468467
auto sourceScalableDims = getSourceVectorType().getScalableDims();
469-
for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
470-
if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
471-
return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
472-
})) {
473-
targetShape.push_back(it.value());
474-
scalableDims.push_back(sourceScalableDims[it.index()]);
468+
for (auto [dimIdx, dimSize] :
469+
llvm::enumerate(getSourceVectorType().getShape()))
470+
if (!llvm::any_of(getReductionDims(),
471+
[dimIdx = dimIdx](int64_t reductionDimIdx) {
472+
return reductionDimIdx == static_cast<int64_t>(dimIdx);
473+
})) {
474+
targetShape.push_back(dimSize);
475+
scalableDims.push_back(sourceScalableDims[dimIdx]);
475476
}
476477
// TODO: update to also allow 0-d vectors when available.
477478
if (targetShape.empty())

mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,7 @@ class InnerOuterDimReductionConversion
6767
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
6868

6969
// Separate reduction and parallel dims
70-
auto reductionDimsRange =
71-
multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
72-
auto reductionDims = llvm::to_vector<4>(llvm::map_range(
73-
reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
70+
ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
7471
llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
7572
reductionDims.end());
7673
int64_t reductionSize = reductionDims.size();

0 commit comments

Comments
 (0)