[mlir][vector] Clarify the semantics of BroadcastOp #101928

Merged 5 commits on Aug 8, 2024
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -68,9 +68,14 @@ enum class BroadcastableToResult {
   DimensionMismatch = 2,
   SourceTypeNotAVector = 3
 };
+
+struct VectorDim {
+  int64_t dim;
+  bool isScalable;
+};
Contributor

There was this MR from @MacDue implementing similar features. I don't know why it got closed, though.
#96236

Contributor Author

This PR is unrelated to that discussion. I'm only adding this here to avoid adding a new set of params to isBroadcastableTo.

I believe that before we commit to any new wider API, we should discuss the internal representation of VectorType and how scalable dimensions are represented. I am working on a proposal, but that's not yet ready to share 😅 I'm hoping to have something in the coming weeks.

Contributor

Ouuh, exciting.

 BroadcastableToResult
 isBroadcastableTo(Type srcType, VectorType dstVectorType,
-                  std::pair<int, int> *mismatchingDims = nullptr);
+                  std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);

 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -367,6 +367,8 @@ def Vector_BroadcastOp :
 s_1 x .. x s_j x .. x s_k
 <duplication> <potential stretch>
 ```
+* in addition, any scalable unit dimension, `[1]`, must match exactly.
+
 The source operand is duplicated over all the missing leading dimensions
 and stretched over the trailing dimensions where the source has a non-equal
 dimension of 1. These rules imply that any scalar broadcast (k=0) to any
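As a rough illustration of the rule documented above (this is not code from the PR), the standalone C++ sketch below labels each destination dimension as duplicated, stretched, or matched. `Dim`, `Role`, and `classifyDims` are hypothetical names that do not exist in MLIR, and the sketch assumes the broadcast has already been verified as valid, with source rank no greater than destination rank.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one vector dimension; not an MLIR type.
struct Dim {
  int64_t size;
  bool scalable;
};

// How a destination dimension is produced from the source.
enum class Role { Duplicated, Stretched, Matched };

// Assumption: the (src, dst) pair is a valid broadcast and
// src.size() <= dst.size().
std::vector<Role> classifyDims(const std::vector<Dim> &src,
                               const std::vector<Dim> &dst) {
  std::vector<Role> roles;
  std::size_t lead = dst.size() - src.size();
  for (std::size_t i = 0; i < dst.size(); ++i) {
    // Missing leading dimensions are simply duplicated.
    if (i < lead) {
      roles.push_back(Role::Duplicated);
      continue;
    }
    const Dim &s = src[i - lead];
    const Dim &d = dst[i];
    // Only a fixed-size unit dim can stretch; a scalable `[1]` must match
    // exactly, so it is never classified as stretched.
    bool stretched = !s.scalable && s.size == 1 && d.size != 1;
    roles.push_back(stretched ? Role::Stretched : Role::Matched);
  }
  return roles;
}
```

For example, a `1x8` source broadcast to `2x4x8` would be classified as duplicated, stretched, matched, in that order.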
50 changes: 37 additions & 13 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
   return res;
 }

-BroadcastableToResult
-mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
-                                std::pair<int, int> *mismatchingDims) {
+BroadcastableToResult mlir::vector::isBroadcastableTo(
+    Type srcType, VectorType dstVectorType,
+    std::pair<VectorDim, VectorDim> *mismatchingDims) {
   // Broadcast scalar to vector of the same element type.
   if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
       getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
@@ -2390,13 +2390,31 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
   // Source has an exact match or singleton value for all trailing dimensions
   // (all leading dimensions are simply duplicated).
   int64_t lead = dstRank - srcRank;
-  for (int64_t r = 0; r < srcRank; ++r) {
-    int64_t srcDim = srcVectorType.getDimSize(r);
-    int64_t dstDim = dstVectorType.getDimSize(lead + r);
-    if (srcDim != 1 && srcDim != dstDim) {
-      if (mismatchingDims) {
-        mismatchingDims->first = srcDim;
-        mismatchingDims->second = dstDim;
+  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
+    // Have mismatching dims (in the sense of vector.broadcast semantics) been
+    // encountered?
+    bool foundMismatchingDims = false;
+
+    // Check fixed-width dims.
+    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
+    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
+    if (srcDim != 1 && srcDim != dstDim)
+      foundMismatchingDims = true;
+
+    // Check scalable flags.
+    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
+    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
+    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
+        (srcDimScalableFlag != dstDimScalableFlag))
+      foundMismatchingDims = true;
+
+    if (foundMismatchingDims) {
+      if (mismatchingDims != nullptr) {
+        mismatchingDims->first.dim = srcDim;
+        mismatchingDims->first.isScalable = srcDimScalableFlag;
+
+        mismatchingDims->second.dim = dstDim;
+        mismatchingDims->second.isScalable = dstDimScalableFlag;
       }
       return BroadcastableToResult::DimensionMismatch;
     }
@@ -2406,16 +2424,22 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
 }

 LogicalResult BroadcastOp::verify() {
-  std::pair<int, int> mismatchingDims;
+  std::pair<VectorDim, VectorDim> mismatchingDims;
   BroadcastableToResult res = isBroadcastableTo(
       getSourceType(), getResultVectorType(), &mismatchingDims);
   if (res == BroadcastableToResult::Success)
     return success();
   if (res == BroadcastableToResult::SourceRankHigher)
     return emitOpError("source rank higher than destination rank");
-  if (res == BroadcastableToResult::DimensionMismatch)
+  if (res == BroadcastableToResult::DimensionMismatch) {
     return emitOpError("dimension mismatch (")
-           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+           << (mismatchingDims.first.isScalable ? "[" : "")
+           << mismatchingDims.first.dim
+           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
+           << (mismatchingDims.second.isScalable ? "[" : "")
+           << mismatchingDims.second.dim
+           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
+  }
   if (res == BroadcastableToResult::SourceTypeNotAVector)
     return emitOpError("source type is not a vector");
   llvm_unreachable("unexpected vector.broadcast op error");
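The updated dimension check and its diagnostic can be modelled outside MLIR with a small standalone C++ sketch. `VectorDim` mirrors the struct added in this PR; `isBroadcastable`, `formatDim`, and `mismatchMessage` are hypothetical helpers written only for this illustration. The sketch assumes both operands are vectors with source rank no greater than destination rank, and it skips the scalar-source and element-type checks that the real function performs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Mirrors the struct introduced by this PR.
struct VectorDim {
  int64_t dim;
  bool isScalable;
};

// Hypothetical helper: prints scalable dims in brackets, e.g. "[4]".
std::string formatDim(const VectorDim &d) {
  std::string s = std::to_string(d.dim);
  return d.isScalable ? "[" + s + "]" : s;
}

// Hypothetical model of the per-dimension check. Assumes
// src.size() <= dst.size(). On failure, the first offending pair is
// written to *mismatch when it is non-null.
bool isBroadcastable(const std::vector<VectorDim> &src,
                     const std::vector<VectorDim> &dst,
                     std::pair<VectorDim, VectorDim> *mismatch = nullptr) {
  std::size_t lead = dst.size() - src.size();
  for (std::size_t i = 0; i < src.size(); ++i) {
    const VectorDim &s = src[i];
    const VectorDim &d = dst[lead + i];
    // Fixed-width rule: sizes must be equal unless the source dim is 1.
    bool sizeMismatch = s.dim != 1 && s.dim != d.dim;
    // Scalable rule: flags must agree, and a scalable [1] cannot stretch.
    bool scalableMismatch = (s.dim == 1 && s.isScalable && d.dim != 1) ||
                            (s.isScalable != d.isScalable);
    if (sizeMismatch || scalableMismatch) {
      if (mismatch != nullptr) {
        mismatch->first = s;
        mismatch->second = d;
      }
      return false;
    }
  }
  return true;
}

// Hypothetical helper: builds the text that follows "op " in the diagnostic.
std::string mismatchMessage(const std::pair<VectorDim, VectorDim> &m) {
  return "dimension mismatch (" + formatDim(m.first) + " vs. " +
         formatDim(m.second) + ")";
}
```

Reusing one out-parameter that carries both the size and the scalable flag keeps the call shape of `isBroadcastableTo` unchanged, which matches the motivation given in the review thread.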
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
@@ -35,6 +35,27 @@ func.func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {

 // -----

+func.func @broadcast_scalable_unit_dim(%arg0: vector<[1]xf32>) {
+  // expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. [4])}}
+  %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32>
+}
+
+// -----
+
+func.func @broadcast_fixed_to_scalable(%arg0: vector<2xf32>) {
+  // expected-error@+1 {{'vector.broadcast' op dimension mismatch (2 vs. [2])}}
+  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<[2]xf32>
+}
+
+// -----
+
+func.func @broadcast_scalable_to_fixed(%arg0: vector<[1]xf32>) {
+  // expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. 1)}}
+  %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x1xf32>
+}
+
+// -----
+
 func.func @broadcast_unknown(%arg0: memref<4x8xf32>) {
   // expected-error@+1 {{'vector.broadcast' op source type is not a vector}}
   %1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>