Skip to content

Commit 59fbba9

Browse files
authored
[mlir][vector] Make ReorderElementwiseOpsOnBroadcast support vector.splat (#66596)
Extend `ReorderElementwiseOpsOnBroadcast` so that the broadcasting op could be either `vector.broadcast` (already supported) as well as `vector.splat` (support added in this patch).
1 parent afd7db4 commit 59fbba9

File tree

2 files changed

+59
-20
lines changed

2 files changed

+59
-20
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
880880
std::function<bool(BitCastOp)> controlFn;
881881
};
882882

883-
/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
883+
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
884884
/// ```
885885
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
886886
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -891,6 +891,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
891891
/// %r = arith.addi %arg0, %arg1 : index
892892
/// %b = vector.broadcast %r : index to vector<1x4xindex>
893893
/// ```
894+
///
895+
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
896+
/// ops.
894897
struct ReorderElementwiseOpsOnBroadcast final
895898
: public OpTraitRewritePattern<OpTrait::Elementwise> {
896899
using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -903,35 +906,42 @@ struct ReorderElementwiseOpsOnBroadcast final
903906
if (!OpTrait::hasElementwiseMappableTraits(op))
904907
return failure();
905908

906-
// Get the type of the first operand
907-
auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
908-
if (!firstBcast)
909+
// Get the type of the lhs operand
910+
auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
911+
if (!lhsBcastOrSplat ||
912+
!isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
909913
return failure();
910-
auto firstOpType = firstBcast.getOperand().getType();
914+
auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
911915

912-
// Make sure that operands are "broadcast"ed from identical (scalar or
913-
// vector) types. That indicates that it's safe to skip the broadcasting of
914-
// operands.
915-
if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
916+
// Make sure that all operands are broadcast from identical types:
917+
// * scalar (`vector.broadcast` + `vector.splat`), or
918+
// * vector (`vector.broadcast`).
919+
// Otherwise the re-ordering wouldn't be safe.
920+
if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
916921
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
917-
return (bcast && (bcast.getOperand().getType() == firstOpType));
922+
if (bcast)
923+
return (bcast.getOperand().getType() == lhsBcastOrSplatType);
924+
auto splat = val.getDefiningOp<vector::SplatOp>();
925+
if (splat)
926+
return (splat.getOperand().getType() == lhsBcastOrSplatType);
927+
return false;
918928
})) {
919929
return failure();
920930
}
921931

922-
// Collect the source values
932+
// Collect the source values before broadcasting
923933
SmallVector<Value> srcValues;
924934
srcValues.reserve(op->getNumOperands());
925-
926935
for (Value operand : op->getOperands()) {
927-
srcValues.push_back(
928-
operand.getDefiningOp<vector::BroadcastOp>().getOperand());
936+
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
929937
}
930938

939+
// Create the "elementwise" Op
931940
Operation *elementwiseOp =
932941
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
933-
firstOpType, op->getAttrs());
942+
lhsBcastOrSplatType, op->getAttrs());
934943

944+
// Replace the original Op with the elementwise Op
935945
auto vectorType = op->getResultTypes()[0];
936946
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
937947
op, vectorType, elementwiseOp->getResults());

mlir/test/Dialect/Vector/sink-vector-broadcast.mlir

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
22

3-
// CHECK-LABEL: func.func @broadcast_scalar(
3+
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
44
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
55
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
66
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
77
// CHECK: return %[[BCAST]] : vector<1x4xindex>
8-
// CHECK: }
98

10-
func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
9+
func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
1110
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
1211
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
1312
%2 = arith.addi %0, %1 : vector<1x4xindex>
@@ -16,20 +15,51 @@ func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
1615

1716
// -----
1817

18+
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
19+
// CHECK-SAME: %[[ARG1:.*]]: index,
20+
// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> {
21+
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
22+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
23+
// CHECK: return %[[BCAST]] : vector<1x4xindex>
24+
func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
25+
%0 = vector.splat %arg1 : vector<1x4xindex>
26+
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
27+
%2 = arith.addi %0, %1 : vector<1x4xindex>
28+
return %2 : vector<1x4xindex>
29+
}
30+
31+
// -----
32+
1933
// CHECK-LABEL: func.func @broadcast_vector(
2034
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
2135
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
2236
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
2337
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
2438
// CHECK: return %[[BCAST]] : vector<3x4xf32>
25-
// CHECK: }
2639

2740
func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
2841
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
2942
%arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
3043
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
3144
return %2 : vector<3x4xf32>
3245
}
46+
47+
// -----
48+
49+
// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
50+
// CHECK-SAME: %[[ARG1:.*]]: index,
51+
// CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
52+
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
53+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
54+
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
55+
// CHECK: return %[[ADD]] : vector<1x4xindex>
56+
func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
57+
%0 = vector.splat %arg1 : vector<1x4xindex>
58+
%1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
59+
%2 = arith.addi %0, %1 : vector<1x4xindex>
60+
return %2 : vector<1x4xindex>
61+
}
62+
3363
// -----
3464

3565
// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
@@ -38,7 +68,6 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
3868
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
3969
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
4070
// CHECK: return %[[ADD]] : vector<4xi32>
41-
// CHECK: }
4271

4372
func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
4473
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>

0 commit comments

Comments
 (0)