Skip to content

Commit f3abfcb

Browse files
committed
changes in vector transforms
1 parent b00d4f2 commit f3abfcb

File tree

6 files changed

+77
-54
lines changed

6 files changed

+77
-54
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ using namespace mlir;
2828
using namespace mlir::vector;
2929

3030
namespace {
31-
/// Progressive lowering of BroadcastOp.
31+
32+
/// Convert a vector.broadcast without a scalar operand to a lower rank
33+
/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
34+
/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
3235
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
3336
public:
3437
using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,20 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
4043
VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
4144
Type eltType = dstType.getElementType();
4245

43-
// Scalar to any vector can use splat.
44-
if (!srcType) {
45-
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
46-
return success();
47-
}
46+
// A broadcast from a scalar is considered to be in the lowered form.
47+
if (!srcType)
48+
return failure();
4849

4950
// Determine rank of source and destination.
5051
int64_t srcRank = srcType.getRank();
5152
int64_t dstRank = dstType.getRank();
5253

53-
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
54-
if (srcRank <= 1 && dstRank == 1) {
55-
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
56-
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
54+
if (srcType.getNumElements() == 1 && dstRank == 1) {
55+
SmallVector<int64_t> fullRankPosition(srcRank, 0);
56+
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
57+
fullRankPosition);
58+
assert(!isa<VectorType>(ext.getType()) && "expected scalar");
59+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
5760
return success();
5861
}
5962

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
468468
read, "vector type is not rank 1, can't create masked load, needs "
469469
"VectorToSCF");
470470

471-
Value fill = vector::SplatOp::create(
471+
Value fill = vector::BroadcastOp::create(
472472
rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
473473
res = vector::MaskedLoadOp::create(
474474
rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice
303303
// Extract/insert on a lower ranked extract strided slice op.
304304
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
305305
rewriter.getZeroAttr(elemType));
306-
Value res = SplatOp::create(rewriter, loc, dstType, zero);
306+
Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
307307
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
308308
off += stride, ++idx) {
309309
Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
939939

940940
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
941941
rewriter.getZeroAttr(elemType));
942-
Value res = SplatOp::create(rewriter, loc, castDstType, zero);
942+
Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
943943

944944
SmallVector<int64_t> sliceShape = {castDstLastDim};
945945
SmallVector<int64_t> strides = {1};
@@ -987,6 +987,23 @@ static Type cloneOrReplace(Type type, Type newElementType) {
987987
return newElementType;
988988
}
989989

990+
/// If `value` is the result of a splat or broadcast operation, return the input
991+
/// of the splat/broadcast operation.
992+
static Value getBroadcastLikeSource(Value value) {
993+
994+
Operation *op = value.getDefiningOp();
995+
if (!op)
996+
return {};
997+
998+
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
999+
return broadcast.getSource();
1000+
1001+
if (auto splat = dyn_cast<vector::SplatOp>(op))
1002+
return splat.getInput();
1003+
1004+
return {};
1005+
}
1006+
9901007
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise).
9911008
///
9921009
/// Example:
@@ -1026,39 +1043,40 @@ struct ReorderElementwiseOpsOnBroadcast final
10261043
}
10271044

10281045
Type resultElemType = resultType.getElementType();
1046+
10291047
// Get the type of the first non-constant operand
1030-
Operation *firstBroadcastOrSplat = nullptr;
1048+
// Operation *firstBroadcastOrSplat = nullptr;
1049+
Value splatSource;
10311050
for (Value operand : op->getOperands()) {
10321051
Operation *definingOp = operand.getDefiningOp();
10331052
if (!definingOp)
10341053
return failure();
10351054
if (definingOp->hasTrait<OpTrait::ConstantLike>())
10361055
continue;
1037-
if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
1038-
return failure();
1039-
firstBroadcastOrSplat = definingOp;
1056+
splatSource = getBroadcastLikeSource(operand);
10401057
break;
10411058
}
1042-
if (!firstBroadcastOrSplat)
1059+
if (!splatSource)
1060+
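// TODO: Document why the pattern bails out when the first non-constant operand is not a broadcast/splat (or when every operand is a constant).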
10431061
return failure();
1044-
Type unbroadcastResultType = cloneOrReplace(
1045-
firstBroadcastOrSplat->getOperand(0).getType(), resultElemType);
1062+
1063+
Type unbroadcastResultType =
1064+
cloneOrReplace(splatSource.getType(), resultElemType);
1065+
1066+
Type lhsBcastOrSplatType = splatSource.getType();
10461067

10471068
// Make sure that all operands are broadcast from identically-shaped types:
10481069
// * scalar (`vector.broadcast` + `vector.splat`), or
10491070
// * vector (`vector.broadcast`).
10501071
// Otherwise the re-ordering wouldn't be safe.
1051-
if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) {
1052-
if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
1053-
return haveSameShapeAndScaling(bcastOp.getOperand().getType(),
1054-
unbroadcastResultType);
1055-
if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
1056-
return haveSameShapeAndScaling(splatOp.getOperand().getType(),
1057-
unbroadcastResultType);
1072+
if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
1073+
if (auto source = getBroadcastLikeSource(val))
1074+
return source.getType() == lhsBcastOrSplatType;
10581075
SplatElementsAttr splatConst;
10591076
return matchPattern(val, m_Constant(&splatConst));
10601077
})) {
1061-
return failure();
1078+
return rewriter.notifyMatchFailure(
1079+
op, "not all operands are broadcasts from the same type");
10621080
}
10631081

10641082
// Collect the source values before broadcasting
@@ -1287,15 +1305,17 @@ class StoreOpFromSplatOrBroadcast final
12871305
return rewriter.notifyMatchFailure(
12881306
op, "only 1-element vectors are supported");
12891307

1290-
Operation *splat = op.getValueToStore().getDefiningOp();
1291-
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
1292-
return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
1308+
Value toStore = op.getValueToStore();
1309+
Value source = getBroadcastLikeSource(toStore);
1310+
if (!source)
1311+
return rewriter.notifyMatchFailure(
1312+
op, "value to store is not from a broadcast");
12931313

12941314
// Checking for single use so we can remove splat.
1315+
Operation *splat = toStore.getDefiningOp();
12951316
if (!splat->hasOneUse())
12961317
return rewriter.notifyMatchFailure(op, "expected single op use");
12971318

1298-
Value source = splat->getOperand(0);
12991319
Value base = op.getBase();
13001320
ValueRange indices = op.getIndices();
13011321

@@ -1345,13 +1365,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
13451365
// Add in an offset if requested.
13461366
if (off) {
13471367
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1348-
Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
1368+
Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
13491369
indices = arith::AddIOp::create(rewriter, loc, ov, indices);
13501370
}
13511371
// Construct the vector comparison.
13521372
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
13531373
Value bounds =
1354-
vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
1374+
vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
13551375
return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
13561376
indices, bounds);
13571377
}

mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
44
// CHECK-SAME: %[[A:.*0]]: f32
5-
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
5+
// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
66
// CHECK: return %[[T0]] : vector<2xf32>
77

88
func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
1212

1313
// CHECK-LABEL: func @broadcast_vec2d_from_scalar
1414
// CHECK-SAME: %[[A:.*0]]: f32
15-
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
15+
// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
1616
// CHECK: return %[[T0]] : vector<2x3xf32>
1717

1818
func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
2222

2323
// CHECK-LABEL: func @broadcast_vec3d_from_scalar
2424
// CHECK-SAME: %[[A:.*0]]: f32
25-
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
25+
// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
2626
// CHECK: return %[[T0]] : vector<2x3x4xf32>
2727

2828
func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
8787
// CHECK-LABEL: func @broadcast_stretch
8888
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
8989
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
90-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
90+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
9191
// CHECK: return %[[T1]] : vector<4xf32>
9292

9393
func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
113113
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
114114
// CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32>
115115
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
116-
// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
116+
// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
117117
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
118118
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
119-
// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
119+
// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
120120
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
121121
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
122-
// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
122+
// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
123123
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
124124
// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
125-
// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
125+
// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
126126
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
127127
// CHECK: return %[[T15]] : vector<4x3xf32>
128128

mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
66
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
77
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
8-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
8+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
99
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
1010
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
1111
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
12-
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
12+
// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
1313
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
1414
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
1515
// CHECK: return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
2626
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
2727
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
2828
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
29-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
29+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
3030
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
3131
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
3232
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
3333
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
34-
// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
34+
// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
3535
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
3636
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
3737
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
4949
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
5050
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
5151
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
52-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
52+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
5353
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
5454
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
5555
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
56-
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
56+
// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
5757
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
5858
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
5959
// CHECK: return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
6969
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
7070
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
7171
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
72-
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
72+
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
7373
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
7474
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
7575
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
7676
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
7777
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
78-
// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
78+
// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
7979
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
8080
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
8181
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
9191
// CHECK-LABEL: func @axpy_fp(
9292
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
9393
// CHECK-SAME: %[[B:.*1]]: f32)
94-
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
94+
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
9595
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
9696
// CHECK: return %[[T1]] : vector<16xf32>
9797
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
103103
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
104104
// CHECK-SAME: %[[B:.*1]]: f32,
105105
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
106-
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
106+
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
107107
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
108108
// CHECK: return %[[T1]] : vector<16xf32>
109109
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
114114
// CHECK-LABEL: func @axpy_int(
115115
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
116116
// CHECK-SAME: %[[B:.*1]]: i32)
117-
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
117+
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
118118
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
119119
// CHECK: return %[[T1]] : vector<16xi32>
120120
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
126126
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
127127
// CHECK-SAME: %[[B:.*1]]: i32,
128128
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
129-
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
129+
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
130130
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
131131
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
132132
// CHECK: return %[[T2]] : vector<16xi32>

0 commit comments

Comments
 (0)