Skip to content

Commit f864c7a

Browse files
committed
[mlir][vector] Add more tests for ConvertVectorToLLVM (1/n)
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass. Covers the following Ops: * vector.bitcast * vector.broadcast Note, this has uncovered some missing logic in `BroadcastOpLowering`. This PR fixes the most basic cases where the scalable flags were dropped and the generated code was incorrect. The `BroadcastOpLowering` pattern is effectively disabled for scalable vectors in more complex cases where an SCF loop would be required to loop over the scalable dims, e.g.: ```mlir %0 = vector.broadcast %arg0 : vector<[4]x1x2xf32> to vector<[4]x3x2xf32> ``` These cases are marked as "Stetch not at start" in the code. In those case, support for scalable vectors is left as a TODO.
1 parent 7c08a8b commit f864c7a

File tree

2 files changed

+242
-1
lines changed

2 files changed

+242
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
125125
// ..
126126
// %x = [%a,%b,%c,%d]
127127
VectorType resType =
128-
VectorType::get(dstType.getShape().drop_front(), eltType);
128+
VectorType::get(dstType.getShape().drop_front(), eltType,
129+
dstType.getScalableDims().drop_front());
129130
Value result = rewriter.create<arith::ConstantOp>(
130131
loc, dstType, rewriter.getZeroAttr(dstType));
131132
if (m == 0) {
@@ -136,6 +137,10 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
136137
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
137138
} else {
138139
// Stetch not at start.
140+
if (dstType.getScalableDims()[0]) {
141+
// TODO: For scalable vectors we should emit an scf.for loop.
142+
return failure();
143+
}
139144
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
140145
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
141146
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);

0 commit comments

Comments
 (0)