Skip to content

Commit d5f407e

Browse files
committed
llvm-array-unroll
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
1 parent 1841bcd commit d5f407e

File tree

3 files changed

+78
-19
lines changed

3 files changed

+78
-19
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 34 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
#include "GPUOpsLowering.h"
1010

1111
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12+
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1213
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1314
#include "mlir/IR/Attributes.h"
1415
#include "mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
586587
return success();
587588
}
588589

589-
/// Unrolls op if it's operating on vectors.
590-
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
591-
ConversionPatternRewriter &rewriter,
592-
const LLVMTypeConverter &converter) {
590+
/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
591+
/// Used either directly (for ops on 1D vectors) or as the callback passed to
592+
/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
593+
static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
594+
Type llvm1DVectorTy,
595+
ConversionPatternRewriter &rewriter,
596+
const LLVMTypeConverter &converter) {
593597
TypeRange operandTypes(operands);
594-
if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
595-
return rewriter.notifyMatchFailure(op, "expected vector operand");
596-
}
597-
if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
598-
return rewriter.notifyMatchFailure(op, "expected no region/successor");
599-
if (op->getNumResults() != 1)
600-
return rewriter.notifyMatchFailure(op, "expected single result");
601-
VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
602-
if (!vectorType)
603-
return rewriter.notifyMatchFailure(op, "expected vector result");
604-
598+
VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
605599
Location loc = op->getLoc();
606600
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
607601
Type indexType = converter.convertType(rewriter.getIndexType());
@@ -621,9 +615,32 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
621615
result = rewriter.create<LLVM::InsertElementOp>(
622616
loc, result, scalarOp->getResult(0), index);
623617
}
618+
return result;
619+
}
624620

625-
rewriter.replaceOp(op, result);
626-
return success();
621+
/// Unrolls op to array/vector elements.
622+
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
623+
ConversionPatternRewriter &rewriter,
624+
const LLVMTypeConverter &converter) {
625+
TypeRange operandTypes(operands);
626+
if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
627+
VectorType vectorType = cast<VectorType>(op->getResultTypes()[0]);
628+
rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
629+
rewriter, converter));
630+
return success();
631+
}
632+
633+
if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
634+
return LLVM::detail::handleMultidimensionalVectors(
635+
op, operands, converter,
636+
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
637+
return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter,
638+
converter);
639+
},
640+
rewriter);
641+
}
642+
643+
return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
627644
}
628645

629646
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,13 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
172172
};
173173

174174
namespace impl {
175-
/// Unrolls op if it's operating on vectors.
175+
/// Unrolls op to array/vector elements.
176176
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
177177
ConversionPatternRewriter &rewriter,
178178
const LLVMTypeConverter &converter);
179179
} // namespace impl
180180

181-
/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
181+
/// Unrolls SourceOp to array/vector elements.
182182
template <typename SourceOp>
183183
struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
184184
public:
@@ -191,6 +191,7 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
191191
*this->getTypeConverter());
192192
}
193193
};
194+
194195
} // namespace mlir
195196

196197
#endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_

mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,44 @@ module {
513513
"test.possible_terminator"() : () -> ()
514514
}) : () -> ()
515515
}
516+
517+
// -----
518+
519+
module @test_module {
520+
// CHECK-LABEL: func @math_sin_vector_1d
521+
func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
522+
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
523+
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
524+
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
525+
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
526+
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
527+
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
528+
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
529+
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
530+
%result = math.sin %arg : vector<4xf16>
531+
func.return %result : vector<4xf16>
532+
}
533+
}
534+
535+
// -----
536+
537+
module @test_module {
538+
// CHECK-LABEL: func @math_sin_vector_2d
539+
func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
540+
// CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
541+
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
542+
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
543+
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
544+
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
545+
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
546+
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
547+
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
548+
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
549+
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
550+
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
551+
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
552+
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
553+
%result = math.sin %arg : vector<2x2xf16>
554+
func.return %result : vector<2x2xf16>
555+
}
556+
}

0 commit comments

Comments
 (0)