Skip to content

Commit 4a411eb

Browse files
authored
[MLIR] Fix rewrite of ops with vector operands to LLVM on GPU (#127844)
There was a discrepancy between the type-converter and rewrite-pattern parts of the conversion to LLVM used by various GPU targets (at least ROCDL and NVVM):

- The TypeConverter part handled vectors of arbitrary rank, converting them to nests of `!llvm.array< ... >` with a vector at the innermost dimension: https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp#L629-L655
- The rewrite-pattern part did not handle `llvm.array`: https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp#L594-L596

That led to conversion failures when lowering `math` dialect ops on rank-2 vectors, as in the test case added in this PR.

This PR fixes the problem by reusing a shared utility that is already used in other conversions to LLVM: https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp#L80-L104

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
1 parent dca7306 commit 4a411eb

File tree

3 files changed

+88
-19
lines changed

3 files changed

+88
-19
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "GPUOpsLowering.h"
1010

1111
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12+
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1213
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1314
#include "mlir/IR/Attributes.h"
1415
#include "mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
586587
return success();
587588
}
588589

589-
/// Unrolls op if it's operating on vectors.
590-
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
591-
ConversionPatternRewriter &rewriter,
592-
const LLVMTypeConverter &converter) {
590+
/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
591+
/// Used either directly (for ops on 1D vectors) or as the callback passed to
592+
/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
593+
static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
594+
Type llvm1DVectorTy,
595+
ConversionPatternRewriter &rewriter,
596+
const LLVMTypeConverter &converter) {
593597
TypeRange operandTypes(operands);
594-
if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
595-
return rewriter.notifyMatchFailure(op, "expected vector operand");
596-
}
597-
if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
598-
return rewriter.notifyMatchFailure(op, "expected no region/successor");
599-
if (op->getNumResults() != 1)
600-
return rewriter.notifyMatchFailure(op, "expected single result");
601-
VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
602-
if (!vectorType)
603-
return rewriter.notifyMatchFailure(op, "expected vector result");
604-
598+
VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
605599
Location loc = op->getLoc();
606600
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
607601
Type indexType = converter.convertType(rewriter.getIndexType());
@@ -621,9 +615,32 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
621615
result = rewriter.create<LLVM::InsertElementOp>(
622616
loc, result, scalarOp->getResult(0), index);
623617
}
618+
return result;
619+
}
624620

625-
rewriter.replaceOp(op, result);
626-
return success();
621+
/// Unrolls op to array/vector elements.
622+
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
                                      ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter &converter) {
  // Reject unsupported ops with a match failure instead of asserting. Both
  // paths below read result 0 of `op`, so exactly one result is required.
  // Ops carrying regions or successors were rejected by this pattern before
  // it learned to handle llvm.array operands; keep rejecting them here.
  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
    return rewriter.notifyMatchFailure(op, "expected no region/successor");
  if (op->getNumResults() != 1)
    return rewriter.notifyMatchFailure(op, "expected single result");

  TypeRange operandTypes(operands);

  // 1-D case: at least one (converted) operand is still a builtin vector, so
  // the op can be unrolled directly into per-element scalar ops.
  if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
    // Use dyn_cast rather than cast: a vector operand does not guarantee a
    // vector result, and cast<> would assert on a non-vector result type.
    auto vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!vectorType)
      return rewriter.notifyMatchFailure(op, "expected vector result");
    rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
                                                   rewriter, converter));
    return success();
  }

  // N-D case: higher-rank vectors have been converted to nested llvm.array
  // values. Delegate the array traversal to the shared utility, which invokes
  // the callback once per inner-most 1-D vector.
  if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
    return LLVM::detail::handleMultidimensionalVectors(
        op, operands, converter,
        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
          return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter,
                                         converter);
        },
        rewriter);
  }

  return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
}
628645

629646
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,13 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
172172
};
173173

174174
namespace impl {
175-
/// Unrolls op if it's operating on vectors.
175+
/// Unrolls op to array/vector elements.
176176
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
177177
ConversionPatternRewriter &rewriter,
178178
const LLVMTypeConverter &converter);
179179
} // namespace impl
180180

181-
/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
181+
/// Unrolls SourceOp to array/vector elements.
182182
template <typename SourceOp>
183183
struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
184184
public:
@@ -191,6 +191,7 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
191191
*this->getTypeConverter());
192192
}
193193
};
194+
194195
} // namespace mlir
195196

196197
#endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_

mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,54 @@ module {
513513
"test.possible_terminator"() : () -> ()
514514
}) : () -> ()
515515
}
516+
517+
// -----
518+
519+
module @test_module {
520+
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
521+
// CHECK-LABEL: func @math_sin_vector_1d
522+
func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
523+
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
524+
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
525+
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
526+
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
527+
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
528+
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
529+
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
530+
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
531+
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
532+
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
533+
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
534+
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
535+
%result = math.sin %arg : vector<4xf16>
536+
func.return %result : vector<4xf16>
537+
}
538+
}
539+
540+
// -----
541+
542+
module @test_module {
543+
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
544+
// CHECK-LABEL: func @math_sin_vector_2d
545+
func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
546+
// CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
547+
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
548+
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
549+
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
550+
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
551+
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
552+
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
553+
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
554+
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
555+
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
556+
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
557+
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
558+
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
559+
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
560+
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
561+
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
562+
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
563+
%result = math.sin %arg : vector<2x2xf16>
564+
func.return %result : vector<2x2xf16>
565+
}
566+
}

0 commit comments

Comments
 (0)