Skip to content

Commit 74d843c

Browse files
committed
fixup! [mlir][vector] Clarify the semantics of BroadcastOp
Address comments from Hugo - thank you!
1 parent 7c08a8b commit 74d843c

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def Vector_BroadcastOp :
367367
s_1 x .. x s_j x .. x s_k
368368
<duplication> <potential stretch>
369369
```
370-
* a scalable unit dimeension, `[1]`, must match exactly.
370+
* a scalable unit dimension, `[1]`, must match exactly.
371371

372372
The source operand is duplicated over all the missing leading dimensions
373373
and stretched over the trailing dimensions where the source has a non-equal

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2403,7 +2403,7 @@ BroadcastableToResult mlir::vector::isBroadcastableTo(
24032403
bool srcDimScalableFlag = srcVectorType.getScalableDims()[r];
24042404
bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r];
24052405
if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2406-
(srcDimScalableFlag && !dstDimScalableFlag))
2406+
(srcDimScalableFlag != dstDimScalableFlag))
24072407
mismatch = true;
24082408

24092409
if (mismatch) {

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ func.func @broadcast_scalable_unit_dim(%arg0: vector<[1]xf32>) {
4242

4343
// -----
4444

45+
func.func @broadcast_fixed_to_scalable(%arg0: vector<2xf32>) {
46+
// expected-error@+1 {{'vector.broadcast' op dimension mismatch (2 vs. [2])}}
47+
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<[2]xf32>
48+
}
49+
50+
// -----
51+
4552
func.func @broadcast_scalable_to_fixed(%arg0: vector<[1]xf32>) {
4653
// expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. 1)}}
4754
%0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x1xf32>

0 commit comments

Comments
 (0)