@@ -2390,29 +2390,29 @@ BroadcastableToResult mlir::vector::isBroadcastableTo(
2390
2390
// Source has an exact match or singleton value for all trailing dimensions
2391
2391
// (all leading dimensions are simply duplicated).
2392
2392
int64_t lead = dstRank - srcRank;
2393
- for (int64_t r = 0 ; r < srcRank; ++r ) {
2393
+ for (int64_t dimIdx = 0 ; dimIdx < srcRank; ++dimIdx ) {
2394
2394
bool mismatch = false ;
2395
2395
2396
- // Check fixed-width dims
2397
- int64_t srcDim = srcVectorType.getDimSize (r );
2398
- int64_t dstDim = dstVectorType.getDimSize (lead + r );
2399
- if (( srcDim != 1 && srcDim != dstDim) )
2396
+ // Check fixed-width dims.
2397
+ int64_t srcDim = srcVectorType.getDimSize (dimIdx );
2398
+ int64_t dstDim = dstVectorType.getDimSize (lead + dimIdx );
2399
+ if (srcDim != 1 && srcDim != dstDim)
2400
2400
mismatch = true ;
2401
2401
2402
- // Check scalable flags
2403
- bool srcDimScalableFlag = srcVectorType.getScalableDims ()[r ];
2404
- bool dstDimScalableFlag = dstVectorType.getScalableDims ()[lead + r ];
2402
+ // Check scalable flags.
2403
+ bool srcDimScalableFlag = srcVectorType.getScalableDims ()[dimIdx ];
2404
+ bool dstDimScalableFlag = dstVectorType.getScalableDims ()[lead + dimIdx ];
2405
2405
if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1 ) ||
2406
2406
(srcDimScalableFlag != dstDimScalableFlag))
2407
2407
mismatch = true ;
2408
2408
2409
2409
if (mismatch) {
2410
- if (mismatchingDims) {
2410
+ if (mismatchingDims != nullptr ) {
2411
2411
mismatchingDims->first .dim = srcDim;
2412
- mismatchingDims->first .scalableFlag = srcDimScalableFlag;
2412
+ mismatchingDims->first .isScalable = srcDimScalableFlag;
2413
2413
2414
2414
mismatchingDims->second .dim = dstDim;
2415
- mismatchingDims->second .scalableFlag = dstDimScalableFlag;
2415
+ mismatchingDims->second .isScalable = dstDimScalableFlag;
2416
2416
}
2417
2417
return BroadcastableToResult::DimensionMismatch;
2418
2418
}
@@ -2430,15 +2430,14 @@ LogicalResult BroadcastOp::verify() {
2430
2430
if (res == BroadcastableToResult::SourceRankHigher)
2431
2431
return emitOpError (" source rank higher than destination rank" );
2432
2432
if (res == BroadcastableToResult::DimensionMismatch) {
2433
- std::string msg =
2434
- (Twine (" dimension mismatch (" ) +
2435
- (mismatchingDims.first .scalableFlag ? " [" : " " ) +
2436
- std::to_string (mismatchingDims.first .dim ) +
2437
- (mismatchingDims.first .scalableFlag ? " ]" : " " ) + " vs. " +
2438
- (mismatchingDims.second .scalableFlag ? " [" : " " ) +
2439
- std::to_string (mismatchingDims.second .dim ) +
2440
- (mismatchingDims.second .scalableFlag ? " ]" : " " ) + " )" )
2441
- .str ();
2433
+ std::string msg = (Twine (" dimension mismatch (" ) +
2434
+ (mismatchingDims.first .isScalable ? " [" : " " ) +
2435
+ std::to_string (mismatchingDims.first .dim ) +
2436
+ (mismatchingDims.first .isScalable ? " ]" : " " ) + " vs. " +
2437
+ (mismatchingDims.second .isScalable ? " [" : " " ) +
2438
+ std::to_string (mismatchingDims.second .dim ) +
2439
+ (mismatchingDims.second .isScalable ? " ]" : " " ) + " )" )
2440
+ .str ();
2442
2441
return emitOpError (msg);
2443
2442
}
2444
2443
if (res == BroadcastableToResult::SourceTypeNotAVector)
0 commit comments