Skip to content

Commit 89be55b

Browse files
committed
fixup! fixup! [mlir][vector] Clarify the semantics of BroadcastOp
Address comments from Jakub
1 parent 74d843c commit 89be55b

File tree

3 files changed

+22
-22
lines changed

3 files changed

+22
-22
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ enum class BroadcastableToResult {
6868
DimensionMismatch = 2,
6969
SourceTypeNotAVector = 3
7070
};
71+
7172
struct VectorDim {
7273
int64_t dim;
73-
bool scalableFlag;
74+
bool isScalable;
7475
};
7576
BroadcastableToResult
7677
isBroadcastableTo(Type srcType, VectorType dstVectorType,

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def Vector_BroadcastOp :
367367
s_1 x .. x s_j x .. x s_k
368368
<duplication> <potential stretch>
369369
```
370-
* a scalable unit dimension, `[1]`, must match exactly.
370+
* in addition, any scalable unit dimension, `[1]`, must match exactly.
371371

372372
The source operand is duplicated over all the missing leading dimensions
373373
and stretched over the trailing dimensions where the source has a non-equal

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2390,29 +2390,29 @@ BroadcastableToResult mlir::vector::isBroadcastableTo(
23902390
// Source has an exact match or singleton value for all trailing dimensions
23912391
// (all leading dimensions are simply duplicated).
23922392
int64_t lead = dstRank - srcRank;
2393-
for (int64_t r = 0; r < srcRank; ++r) {
2393+
for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
23942394
bool mismatch = false;
23952395

2396-
// Check fixed-width dims
2397-
int64_t srcDim = srcVectorType.getDimSize(r);
2398-
int64_t dstDim = dstVectorType.getDimSize(lead + r);
2399-
if ((srcDim != 1 && srcDim != dstDim))
2396+
// Check fixed-width dims.
2397+
int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2398+
int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2399+
if (srcDim != 1 && srcDim != dstDim)
24002400
mismatch = true;
24012401

2402-
// Check scalable flags
2403-
bool srcDimScalableFlag = srcVectorType.getScalableDims()[r];
2404-
bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r];
2402+
// Check scalable flags.
2403+
bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2404+
bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
24052405
if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
24062406
(srcDimScalableFlag != dstDimScalableFlag))
24072407
mismatch = true;
24082408

24092409
if (mismatch) {
2410-
if (mismatchingDims) {
2410+
if (mismatchingDims != nullptr) {
24112411
mismatchingDims->first.dim = srcDim;
2412-
mismatchingDims->first.scalableFlag = srcDimScalableFlag;
2412+
mismatchingDims->first.isScalable = srcDimScalableFlag;
24132413

24142414
mismatchingDims->second.dim = dstDim;
2415-
mismatchingDims->second.scalableFlag = dstDimScalableFlag;
2415+
mismatchingDims->second.isScalable = dstDimScalableFlag;
24162416
}
24172417
return BroadcastableToResult::DimensionMismatch;
24182418
}
@@ -2430,15 +2430,14 @@ LogicalResult BroadcastOp::verify() {
24302430
if (res == BroadcastableToResult::SourceRankHigher)
24312431
return emitOpError("source rank higher than destination rank");
24322432
if (res == BroadcastableToResult::DimensionMismatch) {
2433-
std::string msg =
2434-
(Twine("dimension mismatch (") +
2435-
(mismatchingDims.first.scalableFlag ? "[" : "") +
2436-
std::to_string(mismatchingDims.first.dim) +
2437-
(mismatchingDims.first.scalableFlag ? "]" : "") + " vs. " +
2438-
(mismatchingDims.second.scalableFlag ? "[" : "") +
2439-
std::to_string(mismatchingDims.second.dim) +
2440-
(mismatchingDims.second.scalableFlag ? "]" : "") + ")")
2441-
.str();
2433+
std::string msg = (Twine("dimension mismatch (") +
2434+
(mismatchingDims.first.isScalable ? "[" : "") +
2435+
std::to_string(mismatchingDims.first.dim) +
2436+
(mismatchingDims.first.isScalable ? "]" : "") + " vs. " +
2437+
(mismatchingDims.second.isScalable ? "[" : "") +
2438+
std::to_string(mismatchingDims.second.dim) +
2439+
(mismatchingDims.second.isScalable ? "]" : "") + ")")
2440+
.str();
24422441
return emitOpError(msg);
24432442
}
24442443
if (res == BroadcastableToResult::SourceTypeNotAVector)

0 commit comments

Comments
 (0)