Skip to content

Commit 3d1e82b

Browse files
committed
address review comments, fix rebase conflicts WRT
1 parent c3c69d0 commit 3d1e82b

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ using namespace mlir::vector;
2929

3030
namespace {
3131

32-
/// Convert a vector.broadcast without a scalar operand to a lower rank
32+
/// Convert a vector.broadcast with a vector operand to a lower rank
3333
/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
3434
/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
3535
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
@@ -45,12 +45,15 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
4545

4646
// A broadcast from a scalar is considered to be in the lowered form.
4747
if (!srcType)
48-
return failure();
48+
return rewriter.notifyMatchFailure(
49+
op, "broadcast from scalar already in lowered form");
4950

5051
// Determine rank of source and destination.
5152
int64_t srcRank = srcType.getRank();
5253
int64_t dstRank = dstType.getRank();
5354

55+
// Here we are broadcasting to a rank-1 vector. Ensure that the source is a
56+
// scalar.
5457
if (srcRank <= 1 && dstRank == 1) {
5558
SmallVector<int64_t> fullRankPosition(srcRank, 0);
5659
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,6 @@ struct ReorderElementwiseOpsOnBroadcast final
10451045
Type resultElemType = resultType.getElementType();
10461046

10471047
// Get the type of the first non-constant operand
1048-
// Operation *firstBroadcastOrSplat = nullptr;
10491048
Value splatSource;
10501049
for (Value operand : op->getOperands()) {
10511050
Operation *definingOp = operand.getDefiningOp();
@@ -1057,26 +1056,24 @@ struct ReorderElementwiseOpsOnBroadcast final
10571056
break;
10581057
}
10591058
if (!splatSource)
1060-
// TODO: why?
10611059
return failure();
1062-
10631060
Type unbroadcastResultType =
10641061
cloneOrReplace(splatSource.getType(), resultElemType);
10651062

1066-
Type lhsBcastOrSplatType = splatSource.getType();
1067-
10681063
// Make sure that all operands are broadcast from identically-shaped types:
10691064
// * scalar (`vector.broadcast` + `vector.splat`), or
10701065
// * vector (`vector.broadcast`).
10711066
// Otherwise the re-ordering wouldn't be safe.
1072-
if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
1067+
if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
10731068
if (auto source = getBroadcastLikeSource(val))
1074-
return source.getType() == lhsBcastOrSplatType;
1069+
return haveSameShapeAndScaling(source.getType(),
1070+
splatSource.getType());
10751071
SplatElementsAttr splatConst;
10761072
return matchPattern(val, m_Constant(&splatConst));
10771073
})) {
10781074
return rewriter.notifyMatchFailure(
1079-
op, "not all operands are broadcasts from the sametype");
1075+
op,
1076+
"not all operands are constants or broadcasts from the same type");
10801077
}
10811078

10821079
// Collect the source values before broadcasting

0 commit comments

Comments
 (0)