9
9
#include " GPUOpsLowering.h"
10
10
11
11
#include " mlir/Conversion/GPUCommon/GPUCommonPass.h"
12
+ #include " mlir/Conversion/LLVMCommon/VectorPattern.h"
12
13
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
13
14
#include " mlir/IR/Attributes.h"
14
15
#include " mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
586
587
return success ();
587
588
}
588
589
589
- // / Unrolls op if it's operating on vectors.
590
- LogicalResult impl::scalarizeVectorOp (Operation *op, ValueRange operands,
591
- ConversionPatternRewriter &rewriter,
592
- const LLVMTypeConverter &converter) {
590
+ // / Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
591
+ // / Used either directly (for ops on 1D vectors) or as the callback passed to
592
+ // / detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
593
+ static Value scalarizeVectorOpHelper (Operation *op, ValueRange operands,
594
+ Type llvm1DVectorTy,
595
+ ConversionPatternRewriter &rewriter,
596
+ const LLVMTypeConverter &converter) {
593
597
TypeRange operandTypes (operands);
594
- if (llvm::none_of (operandTypes, llvm::IsaPred<VectorType>)) {
595
- return rewriter.notifyMatchFailure (op, " expected vector operand" );
596
- }
597
- if (op->getNumRegions () != 0 || op->getNumSuccessors () != 0 )
598
- return rewriter.notifyMatchFailure (op, " expected no region/successor" );
599
- if (op->getNumResults () != 1 )
600
- return rewriter.notifyMatchFailure (op, " expected single result" );
601
- VectorType vectorType = dyn_cast<VectorType>(op->getResult (0 ).getType ());
602
- if (!vectorType)
603
- return rewriter.notifyMatchFailure (op, " expected vector result" );
604
-
598
+ VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
605
599
Location loc = op->getLoc ();
606
600
Value result = rewriter.create <LLVM::PoisonOp>(loc, vectorType);
607
601
Type indexType = converter.convertType (rewriter.getIndexType ());
@@ -621,9 +615,32 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
621
615
result = rewriter.create <LLVM::InsertElementOp>(
622
616
loc, result, scalarOp->getResult (0 ), index);
623
617
}
618
+ return result;
619
+ }
624
620
625
- rewriter.replaceOp (op, result);
626
- return success ();
621
+ // / Unrolls op to array/vector elements.
622
+ LogicalResult impl::scalarizeVectorOp (Operation *op, ValueRange operands,
623
+ ConversionPatternRewriter &rewriter,
624
+ const LLVMTypeConverter &converter) {
625
+ TypeRange operandTypes (operands);
626
+ if (llvm::any_of (operandTypes, llvm::IsaPred<VectorType>)) {
627
+ VectorType vectorType = cast<VectorType>(op->getResultTypes ()[0 ]);
628
+ rewriter.replaceOp (op, scalarizeVectorOpHelper (op, operands, vectorType,
629
+ rewriter, converter));
630
+ return success ();
631
+ }
632
+
633
+ if (llvm::any_of (operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
634
+ return LLVM::detail::handleMultidimensionalVectors (
635
+ op, operands, converter,
636
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
637
+ return scalarizeVectorOpHelper (op, operands, llvm1DVectorTy, rewriter,
638
+ converter);
639
+ },
640
+ rewriter);
641
+ }
642
+
643
+ return rewriter.notifyMatchFailure (op, " no llvm.array or vector to unroll" );
627
644
}
628
645
629
646
static IntegerAttr wrapNumericMemorySpace (MLIRContext *ctx, unsigned space) {
0 commit comments