-
Notifications
You must be signed in to change notification settings - Fork 13.9k
[MLIR] Fix rewrite of ops with vector operands to LLVM on GPU #127844
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
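@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Benoit Jacob (bjacob) Changes: There was a discrepancy between the type-converter and rewrite-pattern parts of conversion to LLVM used in various GPU targets, at least ROCDL and NVVM:
- The TypeConverter part was handling vectors of arbitrary rank, converting them to nests of `!llvm.array< ... >` with a vector at the inner-most dimension: llvm-project/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp Lines 629 to 655 in 8337d01
- The rewrite pattern part was not handling `llvm.array`: llvm-project/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp Lines 594 to 596 in 8337d01
That led to conversion failures when lowering `math` dialect ops on rank-2 vectors, as in the testcase being added in this PR. This PR fixes this by reusing a shared utility already used in other conversions to LLVM: llvm-project/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp Lines 80 to 104 in 8337d01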
Full diff: https://github.com/llvm/llvm-project/pull/127844.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index cfa434699cdef..f7e653690f037 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
#include "GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
return success();
}
-/// Unrolls op if it's operating on vectors.
-LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
- ConversionPatternRewriter &rewriter,
- const LLVMTypeConverter &converter) {
+/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
+/// Used either directly (for ops on 1D vectors) or as the callback passed to
+/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
+static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
+ Type llvm1DVectorTy,
+ ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &converter) {
TypeRange operandTypes(operands);
- if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
- return rewriter.notifyMatchFailure(op, "expected vector operand");
- }
- if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
- return rewriter.notifyMatchFailure(op, "expected no region/successor");
- if (op->getNumResults() != 1)
- return rewriter.notifyMatchFailure(op, "expected single result");
- VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
- if (!vectorType)
- return rewriter.notifyMatchFailure(op, "expected vector result");
-
+ VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
Location loc = op->getLoc();
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
Type indexType = converter.convertType(rewriter.getIndexType());
@@ -621,9 +615,32 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
result = rewriter.create<LLVM::InsertElementOp>(
loc, result, scalarOp->getResult(0), index);
}
+ return result;
+}
- rewriter.replaceOp(op, result);
- return success();
+/// Unrolls SourceOp to array/vector elements.
+LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &converter) {
+ TypeRange operandTypes(operands);
+ if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
+ VectorType vectorType = cast<VectorType>(op->getResultTypes()[0]);
+ rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
+ rewriter, converter));
+ return success();
+ }
+
+ if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
+ return LLVM::detail::handleMultidimensionalVectors(
+ op, operands, converter,
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter,
+ converter);
+ },
+ rewriter);
+ }
+
+ return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
}
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index e73a74845d2b6..fd6c23a69a9df 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -172,13 +172,14 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
};
namespace impl {
-/// Unrolls op if it's operating on vectors.
+/// Unrolls SourceOp to array/vector elements.
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter);
} // namespace impl
-/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
+/// Unrolls SourceOp to array elements (which may still be vectors) if it's
+/// operating on arrays.
template <typename SourceOp>
struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
@@ -191,6 +192,7 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
*this->getTypeConverter());
}
};
+
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e4b2f01d6544a..b6493ca9b32c3 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -513,3 +513,44 @@ module {
"test.possible_terminator"() : () -> ()
}) : () -> ()
}
+
+// -----
+
+module @test_module {
+ // CHECK-LABEL: func @math_sin_vector_1d
+ func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ %result = math.sin %arg : vector<4xf16>
+ func.return %result : vector<4xf16>
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK-LABEL: func @math_sin_vector_2d
+ func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
+ // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ %result = math.sin %arg : vector<2x2xf16>
+ func.return %result : vector<2x2xf16>
+ }
+}
|
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
8d208db
to
d5f407e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code change LGTM. Minor comment about testing; you're clear to land it after addressing it or deciding not to.
module @test_module { | ||
// CHECK-LABEL: func @math_sin_vector_1d | ||
func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> { | ||
// CHECK: llvm.extractelement {{.*}} : vector<4xf16> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to check that these are actually getting passed in to a `sin` call?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
…27844) There was a discrepancy between the type-converter and rewrite-pattern parts of conversion to LLVM used in various GPU targets, at least ROCDL and NVVM: - The TypeConverter part was handling vectors of arbitrary rank, converting them to nests of `!llvm.array< ... >` with a vector at the inner-most dimension: https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp#L629-L655 - The rewrite pattern part was not handling `llvm.array`: https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp#L594-L596 That led to conversion failures when lowering `math` dialect ops on rank-2 vectors, as in the testcase being added in this PR. This PR fixes this by reusing a shared utility already used in other conversions to LLVM: https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp#L80-L104 --------- Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
[MLIR] Fix rewrite of ops with vector operands to LLVM on GPU llvm/llvm-project#127844 Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
…27844) There was a discrepancy between the type-converter and rewrite-pattern parts of conversion to LLVM used in various GPU targets, at least ROCDL and NVVM: - The TypeConverter part was handling vectors of arbitrary rank, converting them to nests of `!llvm.array< ... >` with a vector at the inner-most dimension: https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp#L629-L655 - The rewrite pattern part was not handling `llvm.array`: https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp#L594-L596 That led to conversion failures when lowering `math` dialect ops on rank-2 vectors, as in the testcase being added in this PR. This PR fixes this by reusing a shared utility already used in other conversions to LLVM: https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp#L80-L104 --------- Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
[MLIR] Fix rewrite of ops with vector operands to LLVM on GPU. llvm/llvm-project#127844 Previously cherry-picked in #20031 and then dropped in #20015. Needed by the just-merged #19969. Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
There was a discrepancy between the type-converter and rewrite-pattern parts of conversion to LLVM used in various GPU targets, at least ROCDL and NVVM:
- The TypeConverter part was handling vectors of arbitrary rank, converting them to nests of `!llvm.array< ... >` with a vector at the inner-most dimension: llvm-project/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp, Lines 629 to 655 in 8337d01
- The rewrite pattern part was not handling `llvm.array`: llvm-project/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp, Lines 594 to 596 in 8337d01
That led to conversion failures when lowering `math` dialect ops on rank-2 vectors, as in the testcase being added in this PR.
This PR fixes this by reusing a shared utility already used in other conversions to LLVM: llvm-project/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp, Lines 80 to 104 in 8337d01