-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Clarify the semantics of BroadcastOp #101928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Clarifies the semantics of `vector.broadcast` in the context of scalable vectors. In particular, broadcasting a unit scalable dim, `[1]`, is not valid unless there's a match between the output and the input dims. See the examples below for an illustration: ```mlir // VALID %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x[1]xf32> // INVALID %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32> // VALID FIXED-WIDTH EQUIVALENT %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> ``` Documentation, the Op verifier and tests are updated accordingly.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesClarifies the semantics of // VALID
%0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x[1]xf32>
// INVALID
%0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32>
// VALID FIXED-WIDTH EQUIVALENT
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> Documentation, the Op verifier and tests are updated accordingly. Full diff: https://github.com/llvm/llvm-project/pull/101928.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index ac55433fadb2f..9f61f7c866d3d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -68,9 +68,13 @@ enum class BroadcastableToResult {
DimensionMismatch = 2,
SourceTypeNotAVector = 3
};
+struct VectorDim {
+ int64_t dim;
+ bool scalableFlag;
+};
BroadcastableToResult
isBroadcastableTo(Type srcType, VectorType dstVectorType,
- std::pair<int, int> *mismatchingDims = nullptr);
+ std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..08bff3d5e1382 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -367,6 +367,8 @@ def Vector_BroadcastOp :
s_1 x .. x s_j x .. x s_k
<duplication> <potential stretch>
```
+ * a scalable unit dimeension, `[1]`, must match exactly.
+
The source operand is duplicated over all the missing leading dimensions
and stretched over the trailing dimensions where the source has a non-equal
dimension of 1. These rules imply that any scalar broadcast (k=0) to any
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..673c128932893 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
return res;
}
-BroadcastableToResult
-mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
- std::pair<int, int> *mismatchingDims) {
+BroadcastableToResult mlir::vector::isBroadcastableTo(
+ Type srcType, VectorType dstVectorType,
+ std::pair<VectorDim, VectorDim> *mismatchingDims) {
// Broadcast scalar to vector of the same element type.
if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
@@ -2391,12 +2391,28 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
// (all leading dimensions are simply duplicated).
int64_t lead = dstRank - srcRank;
for (int64_t r = 0; r < srcRank; ++r) {
+ bool mismatch = false;
+
+ // Check fixed-width dims
int64_t srcDim = srcVectorType.getDimSize(r);
int64_t dstDim = dstVectorType.getDimSize(lead + r);
- if (srcDim != 1 && srcDim != dstDim) {
+ if ((srcDim != 1 && srcDim != dstDim))
+ mismatch = true;
+
+ // Check scalable flags
+ bool srcDimScalableFlag = srcVectorType.getScalableDims()[r];
+ bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r];
+ if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
+ (srcDimScalableFlag && !dstDimScalableFlag))
+ mismatch = true;
+
+ if (mismatch) {
if (mismatchingDims) {
- mismatchingDims->first = srcDim;
- mismatchingDims->second = dstDim;
+ mismatchingDims->first.dim = srcDim;
+ mismatchingDims->first.scalableFlag = srcDimScalableFlag;
+
+ mismatchingDims->second.dim = dstDim;
+ mismatchingDims->second.scalableFlag = dstDimScalableFlag;
}
return BroadcastableToResult::DimensionMismatch;
}
@@ -2406,16 +2422,25 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
}
LogicalResult BroadcastOp::verify() {
- std::pair<int, int> mismatchingDims;
+ std::pair<VectorDim, VectorDim> mismatchingDims;
BroadcastableToResult res = isBroadcastableTo(
getSourceType(), getResultVectorType(), &mismatchingDims);
if (res == BroadcastableToResult::Success)
return success();
if (res == BroadcastableToResult::SourceRankHigher)
return emitOpError("source rank higher than destination rank");
- if (res == BroadcastableToResult::DimensionMismatch)
- return emitOpError("dimension mismatch (")
- << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+ if (res == BroadcastableToResult::DimensionMismatch) {
+ std::string msg =
+ (Twine("dimension mismatch (") +
+ (mismatchingDims.first.scalableFlag ? "[" : "") +
+ std::to_string(mismatchingDims.first.dim) +
+ (mismatchingDims.first.scalableFlag ? "]" : "") + " vs. " +
+ (mismatchingDims.second.scalableFlag ? "[" : "") +
+ std::to_string(mismatchingDims.second.dim) +
+ (mismatchingDims.second.scalableFlag ? "]" : "") + ")")
+ .str();
+ return emitOpError(msg);
+ }
if (res == BroadcastableToResult::SourceTypeNotAVector)
return emitOpError("source type is not a vector");
llvm_unreachable("unexpected vector.broadcast op error");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 00914c1d1baf6..6dd690be032c7 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -35,6 +35,20 @@ func.func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
// -----
+func.func @broadcast_scalable_unit_dim(%arg0: vector<[1]xf32>) {
+ // expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. [4])}}
+ %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32>
+}
+
+// -----
+
+func.func @broadcast_scalable_to_fixed(%arg0: vector<[1]xf32>) {
+ // expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. 1)}}
+ %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x1xf32>
+}
+
+// -----
+
func.func @broadcast_unknown(%arg0: memref<4x8xf32>) {
// expected-error@+1 {{'vector.broadcast' op source type is not a vector}}
%1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I think there's a missed opportunity here. 😃
struct VectorDim { | ||
int64_t dim; | ||
bool scalableFlag; | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR is unrelated to that discussion. I'm only adding this here to avoid adding new set of params to isBroadcastableTo
.
I believe that before we commit to any new wider API, we should discuss the internal representation of VectorType
and how scalable dimensions are represented. I am working on a proposal, but that's not yet ready to share 😅 I'm hoping to have something in the coming weeks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ouuh, exciting.
bool srcDimScalableFlag = srcVectorType.getScalableDims()[r]; | ||
bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r]; | ||
if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) || | ||
(srcDimScalableFlag && !dstDimScalableFlag)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It got me thinking, what would be the expected behaviour of something like:
%0 = vector.broadcast %arg0 : vector<nxf32> to vector<[n]xf32>
IMO it should not be supported as physically equivalent to a usecase
%1 = vector.broadcast %arg0 : vector<nxf32> to vector<vscale*nxf32>
Which is not invalid for fixed dimensions. Do you think this handles the cases ?
(srcDimScalableFlag && !dstDimScalableFlag)) | |
(srcDimScalableFlag != dstDimScalableFlag)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you have e.g. [2]
and [4]
(i.e. vscale * 2
and vscale * 4
), then that's already "rejected" as "mismatching dims":
if (srcDim != 1 && srcDim != dstDim) {
Is that the case you had in mind?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The case I pointed out was more src = 2
and dest = [2]
. srcDim == dstDim, so no mismatch on line 2399. and we have !srcDimScalableFlag
so no mismatch on line 2406. Whereas I think this is wrong.
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<[2]xf32>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, nice, great catch! In my head I had one case that wouldn't work with !=
, but now I am failing to recall that 😂
Let me send an update - thanks very much for pointing this out 🙏🏻
Address comments from Hugo - thank you!
✅ With the latest revision this PR passed the C/C++ code formatter. |
Address comments from Jakub
587f08e
to
89be55b
Compare
More comments and simplifications
…roadcastOp Avoid using llvm::Twine
Ping @kuhar :) |
@nujaa Any other suggestions from you or shall I land it? |
Clarifies the semantics of
vector.broadcast
in the context of scalablevectors. In particular, broadcasting a unit scalable dim,
[1]
, is notvalid unless there's a match between the output and the input dims.
See the examples below for an illustration:
Documentation, the Op verifier and tests are updated accordingly.