Skip to content

Commit 3c46deb

Browse files
authored
[MLIR] Fix 0-dimensional case of conversion of vector ops to GPU (#128075)
This is a follow-up to #127844. That PR got vectors of arbitrary rank working, but I hadn't thought about the rank-0 case. Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
1 parent 4d92975 commit 3c46deb

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,8 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
624624
const LLVMTypeConverter &converter) {
625625
TypeRange operandTypes(operands);
626626
if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
627-
VectorType vectorType = cast<VectorType>(op->getResultTypes()[0]);
627+
VectorType vectorType =
628+
cast<VectorType>(converter.convertType(op->getResultTypes()[0]));
628629
rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
629630
rewriter, converter));
630631
return success();

mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,20 @@ module {
516516

517517
// -----
518518

519+
module @test_module {
520+
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
521+
// CHECK-LABEL: func @math_sin_vector_0d
522+
func.func @math_sin_vector_0d(%arg : vector<f16>) -> vector<f16> {
523+
// CHECK: llvm.extractelement {{.*}} : vector<1xf16>
524+
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
525+
// CHECK: llvm.insertelement {{.*}} : vector<1xf16>
526+
%result = math.sin %arg : vector<f16>
527+
func.return %result : vector<f16>
528+
}
529+
}
530+
531+
// -----
532+
519533
module @test_module {
520534
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
521535
// CHECK-LABEL: func @math_sin_vector_1d

0 commit comments

Comments
 (0)