Skip to content

Commit 2f11ce5

Browse files
authored
[mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims (#66638)
This extends `vector.constant_mask` so that mask dim sizes that correspond to a scalable dimension are treated as if they're implicitly multiplied by vscale. Currently this is limited to mask dim sizes of 0 or the size of the dim divided by vscale (i.e. the static size written in the vector type). This allows constant masks to represent all true and all false scalable masks (and some variations):

```
// All true scalable mask
%mask = vector.constant_mask [8] : vector<[8]xi1>

// All false scalable mask
%mask = vector.constant_mask [0] : vector<[8]xi1>

// First two scalable rows
%mask = vector.constant_mask [2,4] : vector<4x[4]xi1>
```
1 parent e3b1662 commit 2f11ce5

File tree

6 files changed

+83
-43
lines changed

6 files changed

+83
-43
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2248,7 +2248,10 @@ def Vector_ConstantMaskOp :
22482248
define a hyper-rectangular region within which elements values are set to 1
22492249
(otherwise element values are set to 0). Each value of 'mask_dim_sizes' must
22502250
be non-negative and not greater than the size of the corresponding vector
2251-
dimension (as opposed to vector.create_mask which allows this).
2251+
dimension (as opposed to vector.create_mask which allows this). Sizes that
2252+
correspond to scalable dimensions are implicitly multiplied by vscale,
2253+
though currently only zero (none set) or the size of the dim/vscale
2254+
(all set) are supported.
22522255

22532256
Example:
22542257

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5320,13 +5320,18 @@ LogicalResult ConstantMaskOp::verify() {
53205320
// Verify that each array attr element is in bounds of corresponding vector
53215321
// result dimension size.
53225322
auto resultShape = resultType.getShape();
5323+
auto resultScalableDims = resultType.getScalableDims();
53235324
SmallVector<int64_t, 4> maskDimSizes;
5324-
for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
5325-
int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
5326-
if (attrValue < 0 || attrValue > resultShape[it.index()])
5325+
for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
5326+
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5327+
if (maskDimSize < 0 || maskDimSize > resultShape[index])
53275328
return emitOpError(
53285329
"array attr of size out of bounds of vector result dimension size");
5329-
maskDimSizes.push_back(attrValue);
5330+
if (resultScalableDims[index] && maskDimSize != 0 &&
5331+
maskDimSize != resultShape[index])
5332+
return emitOpError(
5333+
"only supports 'none set' or 'all set' scalable dimensions");
5334+
maskDimSizes.push_back(maskDimSize);
53305335
}
53315336
// Verify that if one mask dim size is zero, they all should be zero (because
53325337
// the mask region is a conjunction of each mask dimension interval).
@@ -5335,14 +5340,6 @@ LogicalResult ConstantMaskOp::verify() {
53355340
if (anyZeros && !allZeros)
53365341
return emitOpError("expected all mask dim sizes to be zeros, "
53375342
"as a result of conjunction with zero mask dim");
5338-
// Verify that if the mask type is scalable, dimensions should be zero because
5339-
// constant scalable masks can only be defined for the "none set" or "all set"
5340-
// cases, and there is no VLA way to define an "all set" case for
5341-
// `vector.constant_mask`. In the future, a convention could be established
5342-
// to decide if a specific dimension value could be considered as "all set".
5343-
if (resultType.isScalable() &&
5344-
llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() != 0)
5345-
return emitOpError("expected mask dim sizes for scalable masks to be 0");
53465343
return success();
53475344
}
53485345

mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
105105
PatternRewriter &rewriter) const override {
106106
auto loc = op.getLoc();
107107
auto dstType = op.getType();
108-
auto eltType = dstType.getElementType();
109108
auto dimSizes = op.getMaskDimSizes();
110109
int64_t rank = dstType.getRank();
111110

@@ -115,43 +114,43 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
115114
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
116115
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
117116
op, dstType,
118-
DenseIntElementsAttr::get(
119-
VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
120-
ArrayRef<bool>{value}));
117+
DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
118+
value));
121119
return success();
122120
}
123121

124-
// Scalable constant masks can only be lowered for the "none set" case.
125-
if (cast<VectorType>(dstType).isScalable()) {
126-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
127-
op, DenseElementsAttr::get(dstType, false));
128-
return success();
129-
}
130-
131-
int64_t trueDim = std::min(dstType.getDimSize(0),
132-
cast<IntegerAttr>(dimSizes[0]).getInt());
122+
int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
133123

134124
if (rank == 1) {
135-
// Express constant 1-D case in explicit vector form:
136-
// [T,..,T,F,..,F].
137-
SmallVector<bool> values(dstType.getDimSize(0));
138-
for (int64_t d = 0; d < trueDim; d++)
139-
values[d] = true;
140-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
141-
op, dstType, rewriter.getBoolVectorAttr(values));
125+
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
126+
// Use constant splat for 'all set' or 'none set' dims.
127+
// This produces correct code for scalable dimensions (it will lower to
128+
// a constant splat).
129+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
130+
op, DenseElementsAttr::get(dstType, trueDimSize != 0));
131+
} else {
132+
// Express constant 1-D case in explicit vector form:
133+
// [T,..,T,F,..,F].
134+
// Note: The verifier would reject this case for scalable vectors.
135+
SmallVector<bool> values(dstType.getDimSize(0), false);
136+
for (int64_t d = 0; d < trueDimSize; d++)
137+
values[d] = true;
138+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
139+
op, dstType, rewriter.getBoolVectorAttr(values));
140+
}
142141
return success();
143142
}
144143

145-
VectorType lowType =
146-
VectorType::get(dstType.getShape().drop_front(), eltType);
147-
SmallVector<int64_t> newDimSizes;
148-
for (int64_t r = 1; r < rank; r++)
149-
newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
144+
if (dstType.getScalableDims().front())
145+
return rewriter.notifyMatchFailure(
146+
op, "Cannot unroll leading scalable dim in dstType");
147+
148+
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
150149
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
151-
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
150+
loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
152151
Value result = rewriter.create<arith::ConstantOp>(
153152
loc, dstType, rewriter.getZeroAttr(dstType));
154-
for (int64_t d = 0; d < trueDim; d++)
153+
for (int64_t d = 0; d < trueDimSize; d++)
155154
result =
156155
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
157156
rewriter.replaceOp(op, result);

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,16 +1819,53 @@ func.func @genbool_1d() -> vector<8xi1> {
18191819

18201820
// -----
18211821

1822-
func.func @genbool_1d_scalable() -> vector<[8]xi1> {
1822+
func.func @genbool_1d_scalable_all_false() -> vector<[8]xi1> {
18231823
%0 = vector.constant_mask [0] : vector<[8]xi1>
18241824
return %0 : vector<[8]xi1>
18251825
}
1826-
// CHECK-LABEL: func @genbool_1d_scalable
1826+
// CHECK-LABEL: func @genbool_1d_scalable_all_false
18271827
// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
18281828
// CHECK: return %[[VAL_0]] : vector<[8]xi1>
18291829

18301830
// -----
18311831

1832+
func.func @genbool_1d_scalable_all_true() -> vector<[8]xi1> {
1833+
%0 = vector.constant_mask [8] : vector<[8]xi1>
1834+
return %0 : vector<[8]xi1>
1835+
}
1836+
// CHECK-LABEL: func @genbool_1d_scalable_all_true
1837+
// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[8]xi1>
1838+
// CHECK: return %[[VAL_0]] : vector<[8]xi1>
1839+
1840+
// -----
1841+
1842+
func.func @genbool_2d_trailing_scalable() -> vector<4x[4]xi1> {
1843+
%0 = vector.constant_mask [2, 4] : vector<4x[4]xi1>
1844+
return %0 : vector<4x[4]xi1>
1845+
}
1846+
// CHECK-LABEL: func.func @genbool_2d_trailing_scalable
1847+
// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[4]xi1>
1848+
// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x[4]xi1>
1849+
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>>
1850+
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>>
1851+
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>>
1852+
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1>
1853+
// CHECK: return %[[VAL_5]] : vector<4x[4]xi1>
1854+
1855+
// -----
1856+
1857+
/// Currently, this is not supported as generating the mask would require
1858+
/// unrolling the leading scalable dimension at compile time.
1859+
func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
1860+
%0 = vector.constant_mask [4, 2] : vector<[4]x4xi1>
1861+
return %0 : vector<[4]x4xi1>
1862+
}
1863+
// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable
1864+
// CHECK: %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1>
1865+
// CHECK: return %[[VAL_0]] : vector<[4]x4xi1>
1866+
1867+
// -----
1868+
18321869
func.func @genbool_2d() -> vector<4x4xi1> {
18331870
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
18341871
return %v: vector<4x4xi1>

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -995,7 +995,7 @@ func.func @constant_mask_with_zero_mask_dim_size() {
995995
// -----
996996

997997
func.func @constant_mask_scalable_non_zero_dim_size() {
998-
// expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}}
998+
// expected-error@+1 {{only supports 'none set' or 'all set' scalable dimensions}}
999999
%0 = vector.constant_mask [2] : vector<[8]xi1>
10001000
}
10011001

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,10 @@ func.func @constant_vector_mask() {
448448
%0 = vector.constant_mask [3, 2] : vector<4x3xi1>
449449
// CHECK: vector.constant_mask [0] : vector<[4]xi1>
450450
%1 = vector.constant_mask [0] : vector<[4]xi1>
451+
// CHECK: vector.constant_mask [4] : vector<[4]xi1>
452+
%2 = vector.constant_mask [4] : vector<[4]xi1>
453+
// CHECK: vector.constant_mask [1, 4] : vector<2x[4]xi1>
454+
%3 = vector.constant_mask [1, 4] : vector<2x[4]xi1>
451455
return
452456
}
453457

0 commit comments

Comments
 (0)