Skip to content

[mlir][vector] Clarify the semantics of BroadcastOp #101928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 8, 2024

Conversation

banach-space
Copy link
Contributor

Clarifies the semantics of vector.broadcast in the context of scalable
vectors. In particular, broadcasting a unit scalable dim, [1], is not
valid unless there's a match between the output and the input dims.
See the examples below for an illustration:

// VALID
 %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x[1]xf32>
// INVALID
 %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32>
// VALID FIXED-WIDTH EQUIVALENT
 %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>

Documentation, the Op verifier and tests are updated accordingly.

Clarifies the semantics of `vector.broadcast` in the context of scalable
vectors. In particular, broadcasting a unit scalable dim, `[1]`, is not
valid unless there's a match between the output and the input dims.
See the examples below for an illustration:

```mlir
// VALID
 %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x[1]xf32>
// INVALID
 %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32>
// VALID FIXED-WIDTH EQUIVALENT
 %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
```

Documentation, the Op verifier and tests are updated accordingly.
@llvmbot
Copy link
Member

llvmbot commented Aug 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Clarifies the semantics of vector.broadcast in the context of scalable
vectors. In particular, broadcasting a unit scalable dim, [1], is not
valid unless there's a match between the output and the input dims.
See the examples below for an illustration:

// VALID
 %0 = vector.broadcast %arg0 : vector&lt;[1]xf32&gt; to vector&lt;4x[1]xf32&gt;
// INVALID
 %0 = vector.broadcast %arg0 : vector&lt;[1]xf32&gt; to vector&lt;[4]xf32&gt;
// VALID FIXED-WIDTH EQUIVALENT
 %0 = vector.broadcast %arg0 : vector&lt;1xf32&gt; to vector&lt;4xf32&gt;

Documentation, the Op verifier and tests are updated accordingly.


Full diff: https://github.com/llvm/llvm-project/pull/101928.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+5-1)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+2)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+35-10)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index ac55433fadb2f..9f61f7c866d3d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -68,9 +68,13 @@ enum class BroadcastableToResult {
   DimensionMismatch = 2,
   SourceTypeNotAVector = 3
 };
+struct VectorDim {
+  int64_t dim;
+  bool scalableFlag;
+};
 BroadcastableToResult
 isBroadcastableTo(Type srcType, VectorType dstVectorType,
-                  std::pair<int, int> *mismatchingDims = nullptr);
+                  std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);
 
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..08bff3d5e1382 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -367,6 +367,8 @@ def Vector_BroadcastOp :
                                s_1     x .. x s_j x .. x s_k
                <duplication>         <potential stretch>
        ```
+    * a scalable unit dimeension, `[1]`, must match exactly.
+
     The source operand is duplicated over all the missing leading dimensions
     and stretched over the trailing dimensions where the source has a non-equal
     dimension of 1. These rules imply that any scalar broadcast (k=0) to any
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..673c128932893 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
   return res;
 }
 
-BroadcastableToResult
-mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
-                                std::pair<int, int> *mismatchingDims) {
+BroadcastableToResult mlir::vector::isBroadcastableTo(
+    Type srcType, VectorType dstVectorType,
+    std::pair<VectorDim, VectorDim> *mismatchingDims) {
   // Broadcast scalar to vector of the same element type.
   if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
       getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
@@ -2391,12 +2391,28 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
   // (all leading dimensions are simply duplicated).
   int64_t lead = dstRank - srcRank;
   for (int64_t r = 0; r < srcRank; ++r) {
+    bool mismatch = false;
+
+    // Check fixed-width dims
     int64_t srcDim = srcVectorType.getDimSize(r);
     int64_t dstDim = dstVectorType.getDimSize(lead + r);
-    if (srcDim != 1 && srcDim != dstDim) {
+    if ((srcDim != 1 && srcDim != dstDim))
+      mismatch = true;
+
+    // Check scalable flags
+    bool srcDimScalableFlag = srcVectorType.getScalableDims()[r];
+    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r];
+    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
+        (srcDimScalableFlag && !dstDimScalableFlag))
+      mismatch = true;
+
+    if (mismatch) {
       if (mismatchingDims) {
-        mismatchingDims->first = srcDim;
-        mismatchingDims->second = dstDim;
+        mismatchingDims->first.dim = srcDim;
+        mismatchingDims->first.scalableFlag = srcDimScalableFlag;
+
+        mismatchingDims->second.dim = dstDim;
+        mismatchingDims->second.scalableFlag = dstDimScalableFlag;
       }
       return BroadcastableToResult::DimensionMismatch;
     }
@@ -2406,16 +2422,25 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
 }
 
 LogicalResult BroadcastOp::verify() {
-  std::pair<int, int> mismatchingDims;
+  std::pair<VectorDim, VectorDim> mismatchingDims;
   BroadcastableToResult res = isBroadcastableTo(
       getSourceType(), getResultVectorType(), &mismatchingDims);
   if (res == BroadcastableToResult::Success)
     return success();
   if (res == BroadcastableToResult::SourceRankHigher)
     return emitOpError("source rank higher than destination rank");
-  if (res == BroadcastableToResult::DimensionMismatch)
-    return emitOpError("dimension mismatch (")
-           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+  if (res == BroadcastableToResult::DimensionMismatch) {
+    std::string msg =
+        (Twine("dimension mismatch (") +
+         (mismatchingDims.first.scalableFlag ? "[" : "") +
+         std::to_string(mismatchingDims.first.dim) +
+         (mismatchingDims.first.scalableFlag ? "]" : "") + " vs. " +
+         (mismatchingDims.second.scalableFlag ? "[" : "") +
+         std::to_string(mismatchingDims.second.dim) +
+         (mismatchingDims.second.scalableFlag ? "]" : "") + ")")
+            .str();
+    return emitOpError(msg);
+  }
   if (res == BroadcastableToResult::SourceTypeNotAVector)
     return emitOpError("source type is not a vector");
   llvm_unreachable("unexpected vector.broadcast op error");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 00914c1d1baf6..6dd690be032c7 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -35,6 +35,20 @@ func.func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
 
 // -----
 
+func.func @broadcast_scalable_unit_dim(%arg0: vector<[1]xf32>) {
+  // expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. [4])}}
+  %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32>
+}
+
+// -----
+
+func.func @broadcast_scalable_to_fixed(%arg0: vector<[1]xf32>) {
+  // expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. 1)}}
+  %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x1xf32>
+}
+
+// -----
+
 func.func @broadcast_unknown(%arg0: memref<4x8xf32>) {
   // expected-error@+1 {{'vector.broadcast' op source type is not a vector}}
   %1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>

Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I think there's a missed opportunity here. 😃

struct VectorDim {
int64_t dim;
bool scalableFlag;
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was this MR from @MacDue , implementing similar features. I dont know why it got closed though.
#96236

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is unrelated to that discussion. I'm only adding this here to avoid adding new set of params to isBroadcastableTo.

I believe that before we commit to any new wider API, we should discuss the internal representation of VectorType and how scalable dimensions are represented. I am working on a proposal, but that's not yet ready to share 😅 I'm hoping to have something in the coming weeks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouuh, exciting.

bool srcDimScalableFlag = srcVectorType.getScalableDims()[r];
bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r];
if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
(srcDimScalableFlag && !dstDimScalableFlag))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It got me thinking, what would be the expected behaviour of something like:

 %0 = vector.broadcast %arg0 : vector<nxf32> to vector<[n]xf32>

IMO it should not be supported as physically equivalent to a usecase

%1 = vector.broadcast %arg0 : vector<nxf32> to vector<vscale*nxf32>

Which is not invalid for fixed dimensions. Do you think this handles the cases ?

Suggested change
(srcDimScalableFlag && !dstDimScalableFlag))
(srcDimScalableFlag != dstDimScalableFlag))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have e.g. [2] and [4] (i.e. vscale * 2 and vscale * 4), then that's already "rejected" as "mismatching dims":

Is that the case you had in mind?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case I pointed out was more src = 2 and dest = [2]. srcDim == dstDim, so no mismatch on line 2399. and we have !srcDimScalableFlag so no mismatch on line 2406. Whereas I think this is wrong.

 %0 = vector.broadcast %arg0 : vector<2xf32> to vector<[2]xf32>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, nice, great catch! In my head I had one case that wouldn't work with !=, but now I am failing to recall that 😂

Let me send an update - thanks very much for pointing this out 🙏🏻

Address comments from Hugo - thank you!
Copy link

github-actions bot commented Aug 6, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space banach-space force-pushed the andrzej/restrict_vec_bcast branch from 587f08e to 89be55b Compare August 6, 2024 09:54
@banach-space
Copy link
Contributor Author

Ping @kuhar :)

@banach-space
Copy link
Contributor Author

@nujaa Any other suggestions from you or shall I land it?

@nujaa nujaa merged commit 1919db9 into llvm:main Aug 8, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants