
Commit def37f7

[mlir][vector] add unroll pattern for broadcast (#142011)
This PR adds `UnrollBroadcastPattern` to the `VectorUnroll` transform. To support this, it also extends the `BroadcastOp` definition with `VectorUnrollOpInterface`.
1 parent 4dcc159 · commit def37f7
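
For illustration only (not part of the commit): with a native unroll shape of [2, 2], the new pattern rewrites a `vector.broadcast` into per-tile broadcasts assembled with strided-slice inserts into a zero-initialized accumulator. The sketch below mirrors the `@vector_broadcast` test added in this commit; the function names are made up for the example.

// Before unrolling (hypothetical example function):
func.func @broadcast_example(%v: vector<4xf32>) -> vector<4x4xf32> {
  %0 = vector.broadcast %v : vector<4xf32> to vector<4x4xf32>
  return %0 : vector<4x4xf32>
}

// After unrolling with native shape [2, 2] (first tile shown; the remaining
// tiles repeat the same extract/broadcast/insert recipe at result offsets
// [0, 2], [2, 0], and [2, 2]):
func.func @broadcast_example_unrolled(%v: vector<4xf32>) -> vector<4x4xf32> {
  // Zero-initialized accumulator for the full result.
  %cst = arith.constant dense<0.000000e+00> : vector<4x4xf32>
  // Slice of the source that feeds the tile at offsets [0, 0].
  %s0 = vector.extract_strided_slice %v {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
  // Per-tile broadcast to the native 2x2 shape.
  %b0 = vector.broadcast %s0 : vector<2xf32> to vector<2x2xf32>
  // Insert the tile back into the accumulator.
  %r0 = vector.insert_strided_slice %b0, %cst {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
  return %r0 : vector<4x4xf32>
}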

File tree

5 files changed (+150, -12 lines)

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
@@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :
 
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [Pure,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"source operand and result have same element type",
        TCresVTEtIsSameAsOpBase<0, 0>>]>,

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 0 deletions
@@ -2522,6 +2522,10 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
 /// Return the dimensions of the result vector that were formerly ones in the
 /// source tensor and thus correspond to "dim-1" broadcasting.
 static llvm::SetVector<int64_t>

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 69 additions & 5 deletions
@@ -631,14 +631,78 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+  UnrollBroadcastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, broadcastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = broadcastOp.getLoc();
+    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType resType = broadcastOp.getResultVectorType();
+    VectorType targetType =
+        resType.cloneWith(*targetShape, resType.getElementType());
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+
+    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+    SmallVector<int64_t> strides(originalShape.size(), 1);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      Value newSrc;
+      if (!srcType) {
+        // Scalar to vector broadcast.
+        newSrc = broadcastOp.getSource();
+      } else {
+        // Vector to vector broadcast.
+        int64_t rank = srcType.getRank();
+        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
+        SmallVector<int64_t> srcShape(targetShape->end() - rank,
+                                      targetShape->end());
+        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
+        // adjust the offset and shape for src if the corresponding dim is 1.
+        for (int64_t i = 0; i < rank; ++i) {
+          if (srcType.getDimSize(i) == 1) {
+            srcOffsets[i] = 0;
+            srcShape[i] = 1;
+          }
+        }
+        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
+      }
+
+      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
+                                                     newSrc, targetType);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, newOp->getResult(0), result, offsets, strides);
+    }
+
+    rewriter.replaceOp(broadcastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-               UnrollContractionPattern, UnrollElementwisePattern,
-               UnrollReductionPattern, UnrollMultiReductionPattern,
-               UnrollTransposePattern, UnrollGatherPattern>(
-      patterns.getContext(), options, benefit);
+  patterns
+      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+           UnrollContractionPattern, UnrollElementwisePattern,
+           UnrollReductionPattern, UnrollMultiReductionPattern,
+           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
+          patterns.getContext(), options, benefit);
 }
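
A note on the dim-1 adjustment in the loop above: when a source dimension has size 1, every result tile reads the same slice along that dimension, so the pattern clamps the source offset to 0 and the slice size to 1 before extracting. The minimal sketch below (function name hypothetical, mirroring the `@vector_broadcast_with_leading_unit_dim` test added in this commit) shows the ops produced for the tile at result offsets [2, 2] when unrolling vector<1x4xf32> to vector<4x4xf32> with native shape [2, 2].

func.func @unit_dim_tile(%arg0: vector<1x4xf32>) -> vector<2x2xf32> {
  // The offset along the unit dim is clamped to 0 and its slice size to 1;
  // only the non-unit dim uses the tile offset 2 and size 2.
  %s = vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
  // Dim-1 broadcast of the slice up to the native 2x2 tile shape.
  %b = vector.broadcast %s : vector<1x2xf32> to vector<2x2xf32>
  return %b : vector<2x2xf32>
}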

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 68 additions & 1 deletion
@@ -196,7 +196,7 @@ func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 // CHECK-LABEL: func @negative_vector_fma_3d
 // CHECK-NOT: vector.extract_strided_slice
 // CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-// CHECK: return
+// CHECK: return
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -311,3 +311,70 @@ func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf
 // BATCHED-COUNT-16: vector.contract
 // BATCHED-NOT: vector.contract
 // BATCHED: return
+
+
+func.func @vector_broadcast(%v: vector<4xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @vector_broadcast
+// CHECK-SAME: [[arg0:%.+]]: vector<4xf32>
+// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return [[r3]] : vector<4x4xf32>
+
+func.func @vector_broadcast_with_leading_unit_dim(%v: vector<1x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<1x4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @vector_broadcast_with_leading_unit_dim
+// CHECK-SAME: ([[arg0:%.+]]: vector<1x4xf32>) -> vector<4x4xf32> {
+// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return [[r3]] : vector<4x4xf32>
+
+func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<4x1xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @vector_broadcast_with_tailing_unit_dim
+// CHECK-SAME: ([[arg0:%.+]]: vector<4x1xf32>) -> vector<4x4xf32> {
+// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2x1xf32> to vector<2x2xf32>
+// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2x1xf32> to vector<2x2xf32>
+// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2x1xf32> to vector<2x2xf32>
+// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
+// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return [[r3]] : vector<4x4xf32>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 8 additions & 6 deletions
@@ -157,12 +157,14 @@ struct TestVectorUnrollingPatterns
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateVectorUnrollPatterns(
-        patterns, UnrollVectorOptions()
-                      .setNativeShape(ArrayRef<int64_t>{2, 2})
-                      .setFilterConstraint([](Operation *op) {
-                        return success(isa<arith::AddFOp, vector::FMAOp,
-                                           vector::MultiDimReductionOp>(op));
-                      }));
+        patterns,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{2, 2})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
+                      vector::BroadcastOp>(op));
+            }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{2})
