Skip to content

Commit 4084038

Browse files
committed
[mlir][vector] Restrict vector.insert/vector.extract
This patch restricts the use of vector.insert and vector.extract Ops in the Vector dialect. Specifically: * The non-indexed operands for `vector.insert` and `vector.extract` must now be non-0-D vectors. The following are now illegal. Note that the source and result types (i.e. non-indexed args) are rank-0 vectors: ```mlir %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32> %1 = vector.extract %arg0[0, 0] : vector<f32> from vector<2x2xf32> ``` Instead, use scalars as the source and result types: ```mlir %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32> %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32> ``` Put differently, this PR removes the ambiguity when it comes to non-indexed operands of `vector.insert` and `vector.extract`. By requiring that only one form is used, it eliminates the flexibility of allowing both, thereby simplifying the semantics. For more context, see the related RFC: * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
1 parent 082598a commit 4084038

File tree

4 files changed

+48
-12
lines changed

4 files changed

+48
-12
lines changed

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,10 @@ struct UnrollTransferReadConversion
12951295

12961296
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
12971297
/// accesses, and broadcasts and transposes in permutation maps.
1298+
///
1299+
/// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
1300+
/// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
1301+
/// MemRef and Tensor source, respectively).
12981302
LogicalResult matchAndRewrite(TransferReadOp xferOp,
12991303
PatternRewriter &rewriter) const override {
13001304
if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1325,6 +1329,8 @@ struct UnrollTransferReadConversion
13251329
for (int64_t i = 0; i < dimSize; ++i) {
13261330
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
13271331

1332+
// FIXME: Rename this lambda - it does much more than just
1333+
// in-bounds-check generation.
13281334
vec = generateInBoundsCheck(
13291335
rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
13301336
/*inBoundsCase=*/
@@ -1339,12 +1345,34 @@ struct UnrollTransferReadConversion
13391345
insertionIndices.push_back(rewriter.getIndexAttr(i));
13401346

13411347
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1342-
auto newXferOp = b.create<vector::TransferReadOp>(
1343-
loc, newXferVecType, xferOp.getSource(), xferIndices,
1344-
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1345-
xferOp.getPadding(), Value(), inBoundsAttr);
1346-
maybeAssignMask(b, xferOp, newXferOp, i);
1347-
return b.create<vector::InsertOp>(loc, newXferOp, vec,
1348+
1349+
// A value that's read after rank-reducing the original
1350+
// vector.transfer_read Op.
1351+
Value unpackedReadRes;
1352+
if (newXferVecType.getRank() != 0) {
1353+
// Unpacking Vector that's rank > 2
1354+
// (use vector.transfer_read to load a rank-reduced vector)
1355+
unpackedReadRes = b.create<vector::TransferReadOp>(
1356+
loc, newXferVecType, xferOp.getSource(), xferIndices,
1357+
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1358+
xferOp.getPadding(), Value(), inBoundsAttr);
1359+
maybeAssignMask(b, xferOp,
1360+
dyn_cast<vector::TransferReadOp>(
1361+
unpackedReadRes.getDefiningOp()),
1362+
i);
1363+
} else {
1364+
// Unpacking Vector that's rank == 1
1365+
// (use memref.load/tensor.extract to load a scalar)
1366+
unpackedReadRes =
1367+
dyn_cast<MemRefType>(xferOp.getSource().getType())
1368+
? b.create<memref::LoadOp>(loc, xferOp.getSource(),
1369+
xferIndices)
1370+
.getResult()
1371+
: b.create<tensor::ExtractOp>(loc, xferOp.getSource(),
1372+
xferIndices)
1373+
.getResult();
1374+
}
1375+
return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
13481376
insertionIndices);
13491377
},
13501378
/*outOfBoundsCase=*/

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
13831383
}
13841384

13851385
LogicalResult vector::ExtractOp::verify() {
1386+
if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1387+
if (resTy.getRank() == 0)
1388+
return emitError(
1389+
"expected a scalar instead of a 0-d vector as the result type");
1390+
13861391
// Note: This check must come before getMixedPosition() to prevent a crash.
13871392
auto dynamicMarkersCount =
13881393
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -2996,6 +3001,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
29963001
}
29973002

29983003
LogicalResult InsertOp::verify() {
3004+
if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3005+
if (srcTy.getRank() == 0)
3006+
return emitError(
3007+
"expected a scalar instead of a 0-d vector as the source operand");
3008+
29993009
SmallVector<OpFoldResult> position = getMixedPosition();
30003010
auto destVectorType = getDestVectorType();
30013011
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
260260
// -----
261261

262262
func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
263-
// expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
264-
%1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
263+
// expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
264+
%1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
265265
}
266266

267267
// -----

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
298298
}
299299

300300
// CHECK-LABEL: @insert_0d
301-
func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
301+
func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
302302
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
303303
%1 = vector.insert %a, %b[] : f32 into vector<f32>
304-
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
305-
%2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
306-
return %1, %2 : vector<f32>, vector<2x3xf32>
304+
return %1 : vector<f32>
307305
}
308306

309307
// CHECK-LABEL: @insert_poison_idx

0 commit comments

Comments
 (0)