|
33 | 33 | #include "mlir/IR/OpImplementation.h"
|
34 | 34 | #include "mlir/IR/PatternMatch.h"
|
35 | 35 | #include "mlir/IR/TypeUtilities.h"
|
| 36 | +#include "mlir/IR/ValueRange.h" |
36 | 37 | #include "mlir/Interfaces/SubsetOpInterface.h"
|
37 | 38 | #include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
38 | 39 | #include "mlir/Support/LLVM.h"
|
@@ -2394,60 +2395,89 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
|
2394 | 2395 | ///
|
2395 | 2396 | /// becomes
|
2396 | 2397 | /// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
|
2397 |
| -static LogicalResult |
2398 |
| -rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp, |
2399 |
| - PatternRewriter &rewriter) { |
| 2398 | +/// |
| 2399 | +/// The requirements for this to be valid are |
| 2400 | +/// i) all elements are extracted from the same vector (source), |
| 2401 | +/// ii) source and from_elements result have the same number of elements, |
| 2402 | +/// iii) the elements are extracted in ascending order. |
| 2403 | +/// |
| 2404 | +/// It might be possible to rewrite vector.from_elements as a single |
| 2405 | +/// vector.extract if (ii) is not satisifed, or in some cases as a |
| 2406 | +/// a single vector_extract_strided_slice if (ii) and (iii) are not satisfied, |
| 2407 | +/// this is left for future consideration. |
| 2408 | +class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> { |
| 2409 | +public: |
| 2410 | + using OpRewritePattern::OpRewritePattern; |
2400 | 2411 |
|
2401 |
| - // The common source of vector.extract operations (if one exists), as well |
2402 |
| - // as its shape and rank. These are set in the first iteration of the loop |
2403 |
| - // over the operands (elements) of `fromElementsOp`. |
2404 |
| - Value source; |
2405 |
| - ArrayRef<int64_t> shape; |
2406 |
| - int64_t rank; |
| 2412 | + LogicalResult matchAndRewrite(FromElementsOp fromElements, |
| 2413 | + PatternRewriter &rewriter) const override { |
2407 | 2414 |
|
2408 |
| - for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) { |
| 2415 | + mlir::OperandRange elements = fromElements.getElements(); |
| 2416 | + assert(!elements.empty() && "must be at least 1 element"); |
2409 | 2417 |
|
2410 |
| - // Check that the element is defined by an extract operation, and that |
2411 |
| - // the extract is on the same vector as all preceding elements. |
2412 |
| - auto extractOp = |
2413 |
| - dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp()); |
2414 |
| - if (!extractOp) |
2415 |
| - return failure(); |
2416 |
| - Value currentSource = extractOp.getVector(); |
2417 |
| - if (index == 0) { |
2418 |
| - source = currentSource; |
2419 |
| - shape = extractOp.getSourceVectorType().getShape(); |
2420 |
| - rank = shape.size(); |
2421 |
| - } else if (currentSource != source) { |
2422 |
| - return failure(); |
| 2418 | + Value firstElement = elements.front(); |
| 2419 | + ExtractOp extractOp = |
| 2420 | + dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp()); |
| 2421 | + if (!extractOp) { |
| 2422 | + return rewriter.notifyMatchFailure( |
| 2423 | + fromElements, "first element not from vector.extract"); |
2423 | 2424 | }
|
| 2425 | + VectorType sourceType = extractOp.getSourceVectorType(); |
| 2426 | + Value source = extractOp.getVector(); |
2424 | 2427 |
|
2425 |
| - // Check that the (linearized) index of extraction is the same as the index |
2426 |
| - // in the result of `fromElementsOp`. |
2427 |
| - ArrayRef<int64_t> position = extractOp.getStaticPosition(); |
2428 |
| - assert(position.size() == rank && |
2429 |
| - "scalar extract must have full rank position"); |
2430 |
| - int64_t stride{1}; |
2431 |
| - int64_t offset{0}; |
2432 |
| - for (auto [pos, size] : |
2433 |
| - llvm::zip(llvm::reverse(position), llvm::reverse(shape))) { |
2434 |
| - if (pos == ShapedType::kDynamic) |
2435 |
| - return failure(); |
2436 |
| - offset += pos * stride; |
2437 |
| - stride *= size; |
| 2428 | + // Check condition (ii). |
| 2429 | + if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) { |
| 2430 | + return rewriter.notifyMatchFailure(fromElements, |
| 2431 | + "number of elements differ"); |
2438 | 2432 | }
|
2439 |
| - if (offset != index) |
2440 |
| - return failure(); |
2441 |
| - } |
2442 | 2433 |
|
2443 |
| - rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp, |
2444 |
| - fromElementsOp.getType(), source); |
2445 |
| -} |
| 2434 | + for (auto [indexMinusOne, element] : |
| 2435 | + llvm::enumerate(elements.drop_front(1))) { |
| 2436 | + |
| 2437 | + extractOp = |
| 2438 | + dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp()); |
| 2439 | + if (!extractOp) { |
| 2440 | + return rewriter.notifyMatchFailure(fromElements, |
| 2441 | + "element not from vector.extract"); |
| 2442 | + } |
| 2443 | + Value currentSource = extractOp.getVector(); |
| 2444 | + // Check condition (i). |
| 2445 | + if (currentSource != source) { |
| 2446 | + return rewriter.notifyMatchFailure(fromElements, |
| 2447 | + "element from different vector"); |
| 2448 | + } |
| 2449 | + |
| 2450 | + ArrayRef<int64_t> position = extractOp.getStaticPosition(); |
| 2451 | + assert(position.size() == static_cast<size_t>(sourceType.getRank()) && |
| 2452 | + "scalar extract must have full rank position"); |
| 2453 | + int64_t stride{1}; |
| 2454 | + int64_t offset{0}; |
| 2455 | + for (auto [pos, size] : llvm::zip(llvm::reverse(position), |
| 2456 | + llvm::reverse(sourceType.getShape()))) { |
| 2457 | + if (pos == ShapedType::kDynamic) { |
| 2458 | + return rewriter.notifyMatchFailure( |
| 2459 | + fromElements, "elements not in ascending order (dynamic order)"); |
| 2460 | + } |
| 2461 | + offset += pos * stride; |
| 2462 | + stride *= size; |
| 2463 | + } |
| 2464 | + // Check condition (iii). |
| 2465 | + if (offset != static_cast<int64_t>(indexMinusOne + 1)) { |
| 2466 | + return rewriter.notifyMatchFailure( |
| 2467 | + fromElements, "elements not in ascending order (static order)"); |
| 2468 | + } |
| 2469 | + } |
| 2470 | + |
| 2471 | + rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements, |
| 2472 | + fromElements.getType(), source); |
| 2473 | + return success(); |
| 2474 | + } |
| 2475 | +}; |
2446 | 2476 |
|
2447 | 2477 | void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
2448 | 2478 | MLIRContext *context) {
|
2449 | 2479 | results.add(rewriteFromElementsAsSplat);
|
2450 |
| - results.add(rewriteFromElementsAsShapeCast); |
| 2480 | + results.add<FromElementsToShapCast>(context); |
2451 | 2481 | }
|
2452 | 2482 |
|
2453 | 2483 | //===----------------------------------------------------------------------===//
|
|
0 commit comments