[MLIR] Fix rewrite of ops with vector operands to LLVM on GPU #127844

Merged: 2 commits merged into llvm:main on Feb 19, 2025

@bjacob bjacob commented Feb 19, 2025

There was a discrepancy between the type-converter and rewrite-pattern parts of conversion to LLVM used in various GPU targets, at least ROCDL and NVVM:

  • The TypeConverter part was handling vectors of arbitrary rank, converting them to nests of !llvm.array< ... > with a vector at the inner-most dimension:
    /// Convert an n-D vector type to an LLVM vector type:
    /// * 0-D `vector<T>` are converted to vector<1xT>
    /// * 1-D `vector<axT>` remains as is while,
    /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
    /// `!llvm.array<ax...array<jxvector<kxT>>>`.
    /// As LLVM supports arrays of scalable vectors, this method will also convert
    /// n-D scalable vectors provided that only the trailing dim is scalable.
    FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
      auto elementType = convertType(type.getElementType());
      if (!elementType)
        return {};
      if (type.getShape().empty())
        return VectorType::get({1}, elementType);
      Type vectorType = VectorType::get(type.getShape().back(), elementType,
                                        type.getScalableDims().back());
      assert(LLVM::isCompatibleVectorType(vectorType) &&
             "expected vector type compatible with the LLVM dialect");
      // For n-D vector types for which a _non-trailing_ dim is scalable,
      // return a failure. Supporting such cases would require LLVM
      // to support something akin "scalable arrays" of vectors.
      if (llvm::is_contained(type.getScalableDims().drop_back(), true))
        return failure();
      auto shape = type.getShape();
      for (int i = shape.size() - 2; i >= 0; --i)
        vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
      return vectorType;
    }
  • The rewrite pattern part was not handling llvm.array:
    if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
      return rewriter.notifyMatchFailure(op, "expected vector operand");
    }

That led to conversion failures when lowering math dialect ops on rank-2 vectors, as in the testcase being added in this PR.
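
To make the mismatch concrete, here is a minimal standalone sketch of the shape mapping that the type converter's doc comment describes. It does not use MLIR: `convertedVectorTypeName` is a hypothetical helper that only renders the resulting type as text, it ignores scalable dimensions, and the printed spelling of nested arrays is an assumption based on how the test in this PR prints `!llvm.array<2 x vector<2xf16>>`.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper (not part of MLIR): renders, as text, the LLVM-dialect
// type described for a fixed-size vector of the given shape.
std::string convertedVectorTypeName(const std::vector<int64_t> &shape,
                                    const std::string &elementType) {
  assert(!elementType.empty() && "expected an element type name");
  // 0-D vectors become vector<1xT>.
  if (shape.empty())
    return "vector<1x" + elementType + ">";
  // The trailing dimension stays a real vector.
  std::string type =
      "vector<" + std::to_string(shape.back()) + "x" + elementType + ">";
  // Each leading dimension wraps the current type in an array, innermost
  // first, mirroring the loop from shape.size() - 2 down to 0.
  bool wrappedInArray = false;
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    type = "array<" + std::to_string(shape[i]) + " x " + type + ">";
    wrappedInArray = true;
  }
  // Assumption: only the outermost array carries the dialect prefix.
  return wrappedInArray ? "!llvm." + type : type;
}
```

With this mapping, a rank-2 operand such as `vector<2x2xf16>` reaches the rewrite pattern as an array type rather than a `VectorType`, which is exactly the case the old `none_of(..., IsaPred<VectorType>)` check rejected.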

This PR fixes this by reusing a shared utility already used in other conversions to LLVM:

LogicalResult LLVM::detail::handleMultidimensionalVectors(
    Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
    std::function<Value(Type, ValueRange)> createOperand,
    ConversionPatternRewriter &rewriter) {
  auto resultNDVectorType = cast<VectorType>(op->getResult(0).getType());
  auto resultTypeInfo =
      extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
  auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
  auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
  auto loc = op->getLoc();
  Value desc = rewriter.create<LLVM::PoisonOp>(loc, resultNDVectoryTy);
  nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
    // For this unrolled `position` corresponding to the `linearIndex`^th
    // element, extract operand vectors
    SmallVector<Value, 4> extractedOperands;
    for (const auto &operand : llvm::enumerate(operands)) {
      extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
          loc, operand.value(), position));
    }
    Value newVal = createOperand(result1DVectorTy, extractedOperands);
    desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position);
  });
  rewriter.replaceOp(op, desc);
  return success();
}
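
The control flow of that utility can be illustrated outside MLIR with plain containers. The sketch below is only a model, under the assumption of a single rank-2 operand: `Vector1D`, `Array2D`, `scalarize1D`, `unrollOverArray` and `sin2D` are hypothetical names, with the outer loop standing in for the `extractvalue` / callback / `insertvalue` sequence and the inner loop for the per-lane `extractelement` / scalar op / `insertelement` sequence.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Stand-in for a value of type !llvm.array<N x vector<KxT>>: N rows, K lanes.
using Vector1D = std::vector<float>;
using Array2D = std::vector<Vector1D>;

// Models the 1-D helper: applies a scalar function lane by lane.
Vector1D scalarize1D(const Vector1D &input,
                     const std::function<float(float)> &scalarOp) {
  Vector1D result(input.size());
  for (std::size_t lane = 0; lane < input.size(); ++lane)
    result[lane] = scalarOp(input[lane]);
  return result;
}

// Models the multi-dimensional utility: for each array position, take the
// 1-D operand stored there, call the callback, and store what it returns.
Array2D unrollOverArray(
    const Array2D &operand,
    const std::function<Vector1D(const Vector1D &)> &create1D) {
  Array2D result(operand.size());
  for (std::size_t position = 0; position < operand.size(); ++position)
    result[position] = create1D(operand[position]);
  return result;
}

// Composes the two levels the way the fix passes the 1-D helper as the
// callback of the multi-dimensional utility.
Array2D sin2D(const Array2D &operand) {
  return unrollOverArray(operand, [](const Vector1D &row) {
    return scalarize1D(row, [](float x) { return std::sin(x); });
  });
}
```

The design point this models is that the per-lane logic is written once and reused at both ranks; only the outer iteration differs.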

llvmbot commented Feb 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Benoit Jacob (bjacob)

Full diff: https://github.com/llvm/llvm-project/pull/127844.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+34-17)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h (+4-2)
  • (modified) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+41)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index cfa434699cdef..f7e653690f037 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   return success();
 }
 
-/// Unrolls op if it's operating on vectors.
-LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
-                                      ConversionPatternRewriter &rewriter,
-                                      const LLVMTypeConverter &converter) {
+/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
+/// Used either directly (for ops on 1D vectors) or as the callback passed to
+/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
+static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
+                                     Type llvm1DVectorTy,
+                                     ConversionPatternRewriter &rewriter,
+                                     const LLVMTypeConverter &converter) {
   TypeRange operandTypes(operands);
-  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
-    return rewriter.notifyMatchFailure(op, "expected vector operand");
-  }
-  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
-    return rewriter.notifyMatchFailure(op, "expected no region/successor");
-  if (op->getNumResults() != 1)
-    return rewriter.notifyMatchFailure(op, "expected single result");
-  VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
-  if (!vectorType)
-    return rewriter.notifyMatchFailure(op, "expected vector result");
-
+  VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
   Location loc = op->getLoc();
   Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
   Type indexType = converter.convertType(rewriter.getIndexType());
@@ -621,9 +615,32 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
     result = rewriter.create<LLVM::InsertElementOp>(
         loc, result, scalarOp->getResult(0), index);
   }
+  return result;
+}
 
-  rewriter.replaceOp(op, result);
-  return success();
+/// Unrolls SourceOp to array/vector elements.
+LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
+                                      ConversionPatternRewriter &rewriter,
+                                      const LLVMTypeConverter &converter) {
+  TypeRange operandTypes(operands);
+  if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
+    VectorType vectorType = cast<VectorType>(op->getResultTypes()[0]);
+    rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
+                                                   rewriter, converter));
+    return success();
+  }
+
+  if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
+    return LLVM::detail::handleMultidimensionalVectors(
+        op, operands, converter,
+        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+          return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter,
+                                         converter);
+        },
+        rewriter);
+  }
+
+  return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
 }
 
 static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index e73a74845d2b6..fd6c23a69a9df 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -172,13 +172,14 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
 };
 
 namespace impl {
-/// Unrolls op if it's operating on vectors.
+/// Unrolls SourceOp to array/vector elements.
 LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const LLVMTypeConverter &converter);
 } // namespace impl
 
-/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
+/// Unrolls SourceOp to array elements (which may still be vectors) if it's
+/// operating on arrays.
 template <typename SourceOp>
 struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
@@ -191,6 +192,7 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
                                    *this->getTypeConverter());
   }
 };
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e4b2f01d6544a..b6493ca9b32c3 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -513,3 +513,44 @@ module {
     "test.possible_terminator"() : () -> ()
   }) : () -> ()
 }
+
+// -----
+
+module @test_module {
+  // CHECK-LABEL: func @math_sin_vector_1d
+  func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    %result = math.sin %arg : vector<4xf16>
+    func.return %result : vector<4xf16>
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK-LABEL: func @math_sin_vector_2d
+  func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
+    // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>    
+    %result = math.sin %arg : vector<2x2xf16>
+    func.return %result : vector<2x2xf16>
+  }
+}
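
The number of CHECK lines in the two test functions above follows from the shapes. As a rough sanity check, the sketch below counts the extract operations one would expect for a single-operand, fixed-shape op, assuming one `llvm.extractvalue` per position in the leading dimensions and one `llvm.extractelement` per scalar element. `UnrollCounts` and `countUnrolledOps` are hypothetical names used only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical summary of how many extract ops the unrolling would emit.
struct UnrollCounts {
  int64_t arrayExtracts;   // llvm.extractvalue on the !llvm.array operand
  int64_t elementExtracts; // llvm.extractelement on the 1-D vectors
};

// Assumes a non-empty, fixed (non-scalable) shape and a single operand.
UnrollCounts countUnrolledOps(const std::vector<int64_t> &shape) {
  assert(!shape.empty() && "rank-0 is outside the scope of this sketch");
  // Product of all leading dimensions: the number of 1-D vectors.
  int64_t outerPositions = 1;
  for (std::size_t i = 0; i + 1 < shape.size(); ++i)
    outerPositions *= shape[i];
  // Rank-1 operands are already vectors, so no array extraction is needed.
  const bool hasArray = shape.size() > 1;
  return {hasArray ? outerPositions : 0, outerPositions * shape.back()};
}
```

For `vector<4xf16>` this gives 0 array extracts and 4 element extracts; for `vector<2x2xf16>` it gives 2 and 4, which lines up with the `llvm.extractvalue` and `llvm.extractelement` lines in the tests.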

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>

@krzysz00 krzysz00 left a comment

Code change lgtm, minor comment about testing, you're clear to land it after addressing it or deciding not to

module @test_module {
// CHECK-LABEL: func @math_sin_vector_1d
func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
Contributor

Do we want to check to see these actually getting passed in to a sin call?

Contributor Author

Done!

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
@bjacob bjacob merged commit 4a411eb into llvm:main Feb 19, 2025
5 of 6 checks passed
bjacob added a commit to iree-org/llvm-project that referenced this pull request Feb 19, 2025
…27844)

There was a discrepancy between the type-converter and rewrite-pattern
parts of conversion to LLVM used in various GPU targets, at least ROCDL
and NVVM:
- The TypeConverter part was handling vectors of arbitrary rank,
converting them to nests of `!llvm.array< ... >` with a vector at the
inner-most dimension:
https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp#L629-L655
- The rewrite pattern part was not handling `llvm.array`:
https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp#L594-L596

That led to conversion failures when lowering `math` dialect ops on
rank-2 vectors, as in the testcase being added in this PR.

This PR fixes this by reusing a shared utility already used in other
conversions to LLVM:

https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp#L80-L104

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob added a commit to iree-org/iree that referenced this pull request Feb 19, 2025
[MLIR] Fix rewrite of ops with vector operands to LLVM on GPU

llvm/llvm-project#127844

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob added a commit to iree-org/iree that referenced this pull request Feb 19, 2025
[MLIR] Fix rewrite of ops with vector operands to LLVM on GPU.

llvm/llvm-project#127844

Previously cherry-picked in #20031
and then dropped in #20015.

Needed by the just-merged #19969.

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob added a commit that referenced this pull request Feb 21, 2025
…8075)

This is a follow-up to #127844. That PR got vectors of arbitrary rank
working, but I hadn't thought about the rank-0 case.

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>