Skip to content

Commit 394b6be

Browse files
committed
fix blindspot where source is larger
1 parent d4cd92a commit 394b6be

File tree

2 files changed

+84
-43
lines changed

2 files changed

+84
-43
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 73 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/IR/OpImplementation.h"
3434
#include "mlir/IR/PatternMatch.h"
3535
#include "mlir/IR/TypeUtilities.h"
36+
#include "mlir/IR/ValueRange.h"
3637
#include "mlir/Interfaces/SubsetOpInterface.h"
3738
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
3839
#include "mlir/Support/LLVM.h"
@@ -2394,60 +2395,89 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
23942395
///
23952396
/// becomes
23962397
/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
2397-
static LogicalResult
2398-
rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
2399-
PatternRewriter &rewriter) {
2398+
///
2399+
/// The requirements for this to be valid are
2400+
/// i) all elements are extracted from the same vector (source),
2401+
/// ii) source and from_elements result have the same number of elements,
2402+
/// iii) the elements are extracted in ascending order.
2403+
///
2404+
/// It might be possible to rewrite vector.from_elements as a single
2405+
/// vector.extract if (ii) is not satisifed, or in some cases as a
2406+
/// a single vector_extract_strided_slice if (ii) and (iii) are not satisfied,
2407+
/// this is left for future consideration.
2408+
class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
2409+
public:
2410+
using OpRewritePattern::OpRewritePattern;
24002411

2401-
// The common source of vector.extract operations (if one exists), as well
2402-
// as its shape and rank. These are set in the first iteration of the loop
2403-
// over the operands (elements) of `fromElementsOp`.
2404-
Value source;
2405-
ArrayRef<int64_t> shape;
2406-
int64_t rank;
2412+
LogicalResult matchAndRewrite(FromElementsOp fromElements,
2413+
PatternRewriter &rewriter) const override {
24072414

2408-
for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) {
2415+
mlir::OperandRange elements = fromElements.getElements();
2416+
assert(!elements.empty() && "must be at least 1 element");
24092417

2410-
// Check that the element is defined by an extract operation, and that
2411-
// the extract is on the same vector as all preceding elements.
2412-
auto extractOp =
2413-
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
2414-
if (!extractOp)
2415-
return failure();
2416-
Value currentSource = extractOp.getVector();
2417-
if (index == 0) {
2418-
source = currentSource;
2419-
shape = extractOp.getSourceVectorType().getShape();
2420-
rank = shape.size();
2421-
} else if (currentSource != source) {
2422-
return failure();
2418+
Value firstElement = elements.front();
2419+
ExtractOp extractOp =
2420+
dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
2421+
if (!extractOp) {
2422+
return rewriter.notifyMatchFailure(
2423+
fromElements, "first element not from vector.extract");
24232424
}
2425+
VectorType sourceType = extractOp.getSourceVectorType();
2426+
Value source = extractOp.getVector();
24242427

2425-
// Check that the (linearized) index of extraction is the same as the index
2426-
// in the result of `fromElementsOp`.
2427-
ArrayRef<int64_t> position = extractOp.getStaticPosition();
2428-
assert(position.size() == rank &&
2429-
"scalar extract must have full rank position");
2430-
int64_t stride{1};
2431-
int64_t offset{0};
2432-
for (auto [pos, size] :
2433-
llvm::zip(llvm::reverse(position), llvm::reverse(shape))) {
2434-
if (pos == ShapedType::kDynamic)
2435-
return failure();
2436-
offset += pos * stride;
2437-
stride *= size;
2428+
// Check condition (ii).
2429+
if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) {
2430+
return rewriter.notifyMatchFailure(fromElements,
2431+
"number of elements differ");
24382432
}
2439-
if (offset != index)
2440-
return failure();
2441-
}
24422433

2443-
rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
2444-
fromElementsOp.getType(), source);
2445-
}
2434+
for (auto [indexMinusOne, element] :
2435+
llvm::enumerate(elements.drop_front(1))) {
2436+
2437+
extractOp =
2438+
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
2439+
if (!extractOp) {
2440+
return rewriter.notifyMatchFailure(fromElements,
2441+
"element not from vector.extract");
2442+
}
2443+
Value currentSource = extractOp.getVector();
2444+
// Check condition (i).
2445+
if (currentSource != source) {
2446+
return rewriter.notifyMatchFailure(fromElements,
2447+
"element from different vector");
2448+
}
2449+
2450+
ArrayRef<int64_t> position = extractOp.getStaticPosition();
2451+
assert(position.size() == static_cast<size_t>(sourceType.getRank()) &&
2452+
"scalar extract must have full rank position");
2453+
int64_t stride{1};
2454+
int64_t offset{0};
2455+
for (auto [pos, size] : llvm::zip(llvm::reverse(position),
2456+
llvm::reverse(sourceType.getShape()))) {
2457+
if (pos == ShapedType::kDynamic) {
2458+
return rewriter.notifyMatchFailure(
2459+
fromElements, "elements not in ascending order (dynamic order)");
2460+
}
2461+
offset += pos * stride;
2462+
stride *= size;
2463+
}
2464+
// Check condition (iii).
2465+
if (offset != static_cast<int64_t>(indexMinusOne + 1)) {
2466+
return rewriter.notifyMatchFailure(
2467+
fromElements, "elements not in ascending order (static order)");
2468+
}
2469+
}
2470+
2471+
rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements,
2472+
fromElements.getType(), source);
2473+
return success();
2474+
}
2475+
};
24462476

24472477
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
24482478
MLIRContext *context) {
24492479
results.add(rewriteFromElementsAsSplat);
2450-
results.add(rewriteFromElementsAsShapeCast);
2480+
results.add<FromElementsToShapCast>(context);
24512481
}
24522482

24532483
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,14 @@ func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi
143143
%2 = vector.from_elements %0, %1 : vector<2xi8>
144144
return %2 : vector<2xi8>
145145
}
146+
147+
// -----
148+
149+
// CHECK-LABEL: func @negative_source_too_large(
150+
// CHECK-NOT: shape_cast
151+
func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
152+
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
153+
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
154+
%2 = vector.from_elements %0, %1 : vector<2xi8>
155+
return %2 : vector<2xi8>
156+
}

0 commit comments

Comments
 (0)