Skip to content

Commit e84f6b6

Browse files
authored
[mlir] Fix conflict of user defined reserved functions with internal prototypes (#123378)
On lowering from `memref` to LLVM, `malloc` and other intrinsic functions from `libc` will be declared in the current module. User's redefinition of these reserved functions will poison the internal analysis with wrong prototype. This patch adds assertion on the found function's type and reports if it mismatch with the intended type. Related to #120950 --------- Co-authored-by: Luohao Wang <Luohaothu@users.noreply.github.com>
1 parent 13dcc95 commit e84f6b6

File tree

11 files changed

+224
-137
lines changed

11 files changed

+224
-137
lines changed

mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,10 @@ namespace LLVM {
2323
/// Generate IR that prints the given string to stdout.
2424
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
2525
/// have the signature void(char const*). The default function is `printString`.
26-
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
27-
StringRef symbolName, StringRef string,
28-
const LLVMTypeConverter &typeConverter,
29-
bool addNewline = true,
30-
std::optional<StringRef> runtimeFunctionName = {});
26+
LogicalResult createPrintStrCall(
27+
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
28+
StringRef string, const LLVMTypeConverter &typeConverter,
29+
bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
3130
} // namespace LLVM
3231

3332
} // namespace mlir

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/Support/LLVM.h"
19-
#include <optional>
2019

2120
namespace mlir {
2221
class Location;
@@ -29,42 +28,47 @@ class ValueRange;
2928
namespace LLVM {
3029
class LLVMFuncOp;
3130

32-
/// Helper functions to lookup or create the declaration for commonly used
31+
/// Helper functions to look up or create the declaration for commonly used
3332
/// external C function calls. The list of functions provided here must be
3433
/// implemented separately (e.g. as part of a support runtime library or as part
3534
/// of the libc).
36-
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
37-
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
38-
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
39-
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
40-
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
41-
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
35+
/// Failure if an unexpected version of function is found.
36+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
37+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
38+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
39+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
40+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
41+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
4242
/// Declares a function to print a C-string.
4343
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
4444
/// have the signature void(char const*). The default function is `printString`.
45-
LLVM::LLVMFuncOp
45+
FailureOr<LLVM::LLVMFuncOp>
4646
lookupOrCreatePrintStringFn(Operation *moduleOp,
4747
std::optional<StringRef> runtimeFunctionName = {});
48-
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
49-
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
50-
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
51-
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
52-
LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
53-
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
54-
Type indexType);
55-
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
56-
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
57-
Type indexType);
58-
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
59-
Type indexType);
60-
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
61-
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
62-
Type unrankedDescriptorType);
48+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
49+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
50+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
51+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
52+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp,
53+
Type indexType);
54+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
55+
Type indexType);
56+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
57+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
58+
Type indexType);
59+
FailureOr<LLVM::LLVMFuncOp>
60+
lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType);
61+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
62+
FailureOr<LLVM::LLVMFuncOp>
63+
lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
64+
Type unrankedDescriptorType);
6365

6466
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
65-
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
66-
ArrayRef<Type> paramTypes = {},
67-
Type resultType = {}, bool isVarArg = false);
67+
/// Return a failure if the FuncOp found has unexpected signature.
68+
FailureOr<LLVM::LLVMFuncOp>
69+
lookupOrCreateFn(Operation *moduleOp, StringRef name,
70+
ArrayRef<Type> paramTypes = {}, Type resultType = {},
71+
bool isVarArg = false, bool isReserved = false);
6872

6973
} // namespace LLVM
7074
} // namespace mlir

mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
396396
// Allocate memory for the coroutine frame.
397397
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
398398
op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
399+
if (failed(allocFuncOp))
400+
return failure();
399401
auto coroAlloc = rewriter.create<LLVM::CallOp>(
400-
loc, allocFuncOp, ValueRange{coroAlign, coroSize});
402+
loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
401403

402404
// Begin a coroutine: @llvm.coro.begin.
403405
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
@@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
431433
// Free the memory.
432434
auto freeFuncOp =
433435
LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
434-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
436+
if (failed(freeFuncOp))
437+
return failure();
438+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
435439
ValueRange(coroMem.getResult()));
436440

437441
return success();

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
6161

6262
// Failed block: Generate IR to print the message and call `abort`.
6363
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64-
LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
65-
*getTypeConverter(), /*addNewLine=*/false,
66-
/*runtimeFunctionName=*/"puts");
64+
auto createResult = LLVM::createPrintStrCall(
65+
rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
66+
/*addNewLine=*/false,
67+
/*runtimeFunctionName=*/"puts");
68+
if (createResult.failed())
69+
return failure();
70+
6771
if (abortOnFailedAssert) {
6872
// Insert the `abort` declaration if necessary.
6973
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
276276

277277
// Find the malloc and free, or declare them if necessary.
278278
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
279-
LLVM::LLVMFuncOp freeFunc, mallocFunc;
280-
if (toDynamic)
279+
FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
280+
if (toDynamic) {
281281
mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
282-
if (!toDynamic)
282+
if (failed(mallocFunc))
283+
return failure();
284+
}
285+
if (!toDynamic) {
283286
freeFunc = LLVM::lookupOrCreateFreeFn(module);
287+
if (failed(freeFunc))
288+
return failure();
289+
}
284290

285291
unsigned unrankedMemrefPos = 0;
286292
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
@@ -293,7 +299,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
293299
// Allocate memory, copy, and free the source if necessary.
294300
Value memory =
295301
toDynamic
296-
? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
302+
? builder
303+
.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
297304
.getResult()
298305
: builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
299306
IntegerType::get(getContext(), 8),
@@ -302,7 +309,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
302309
Value source = desc.memRefDescPtr(builder, loc);
303310
builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
304311
if (!toDynamic)
305-
builder.create<LLVM::CallOp>(loc, freeFunc, source);
312+
builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
306313

307314
// Create a new descriptor. The same descriptor can be returned multiple
308315
// times, attempting to modify its pointer can lead to memory leaks

mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
2727
return uniqueName;
2828
}
2929

30-
void mlir::LLVM::createPrintStrCall(
30+
LogicalResult mlir::LLVM::createPrintStrCall(
3131
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
3232
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
3333
std::optional<StringRef> runtimeFunctionName) {
@@ -59,8 +59,11 @@ void mlir::LLVM::createPrintStrCall(
5959
SmallVector<LLVM::GEPArg> indices(1, 0);
6060
Value gep =
6161
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
62-
Operation *printer =
62+
FailureOr<LLVM::LLVMFuncOp> printer =
6363
LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
64-
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
65-
gep);
64+
if (failed(printer))
65+
return failure();
66+
builder.create<LLVM::CallOp>(loc, TypeRange(),
67+
SymbolRefAttr::get(printer.value()), gep);
68+
return success();
6669
}

mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@
1414

1515
using namespace mlir;
1616

17-
namespace {
18-
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
19-
Operation *module, Type indexType) {
17+
static FailureOr<LLVM::LLVMFuncOp>
18+
getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
19+
Type indexType) {
2020
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
2121
if (useGenericFn)
2222
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
2323

2424
return LLVM::lookupOrCreateMallocFn(module, indexType);
2525
}
2626

27-
LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
28-
Operation *module, Type indexType) {
27+
static FailureOr<LLVM::LLVMFuncOp>
28+
getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
29+
Type indexType) {
2930
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
3031

3132
if (useGenericFn)
@@ -34,8 +35,6 @@ LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
3435
return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
3536
}
3637

37-
} // end namespace
38-
3938
Value AllocationOpLLVMLowering::createAligned(
4039
ConversionPatternRewriter &rewriter, Location loc, Value input,
4140
Value alignment) {
@@ -80,10 +79,13 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
8079
<< " to integer address space "
8180
"failed. Consider adding memory space conversions.";
8281
}
83-
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
82+
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
8483
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
8584
getIndexType());
86-
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
85+
if (failed(allocFuncOp))
86+
return std::make_tuple(Value(), Value());
87+
auto results =
88+
rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
8789

8890
Value allocatedPtr =
8991
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -146,11 +148,13 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
146148
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
147149

148150
Type elementPtrType = this->getElementPtrType(memRefType);
149-
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
151+
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
150152
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
151153
getIndexType());
154+
if (failed(allocFuncOp))
155+
return Value();
152156
auto results = rewriter.create<LLVM::CallOp>(
153-
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
157+
loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
154158

155159
return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
156160
elementPtrType, *getTypeConverter());

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ using namespace mlir;
3838

3939
namespace {
4040

41-
bool isStaticStrideOrOffset(int64_t strideOrOffset) {
41+
static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
4242
return !ShapedType::isDynamic(strideOrOffset);
4343
}
4444

45-
LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
46-
ModuleOp module) {
45+
static FailureOr<LLVM::LLVMFuncOp>
46+
getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
4747
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
4848

4949
if (useGenericFn)
@@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
220220
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
221221
ConversionPatternRewriter &rewriter) const override {
222222
// Insert the `free` declaration if it is not already present.
223-
LLVM::LLVMFuncOp freeFunc =
223+
FailureOr<LLVM::LLVMFuncOp> freeFunc =
224224
getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
225+
if (failed(freeFunc))
226+
return failure();
225227
Value allocatedPtr;
226228
if (auto unrankedTy =
227229
llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
@@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
236238
allocatedPtr = MemRefDescriptor(adaptor.getMemref())
237239
.allocatedPtr(rewriter, op.getLoc());
238240
}
239-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
241+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
242+
allocatedPtr);
240243
return success();
241244
}
242245
};
@@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
838841
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
839842
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
840843
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
841-
rewriter.create<LLVM::CallOp>(loc, copyFn,
844+
if (failed(copyFn))
845+
return failure();
846+
rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
842847
ValueRange{elemSize, sourcePtr, targetPtr});
843848

844849
// Restore stack used for descriptors

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,11 +1546,15 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
15461546

15471547
auto punct = printOp.getPunctuation();
15481548
if (auto stringLiteral = printOp.getStringLiteral()) {
1549-
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1550-
*stringLiteral, *getTypeConverter(),
1551-
/*addNewline=*/false);
1549+
auto createResult =
1550+
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1551+
*stringLiteral, *getTypeConverter(),
1552+
/*addNewline=*/false);
1553+
if (createResult.failed())
1554+
return failure();
1555+
15521556
} else if (punct != PrintPunctuation::NoPunctuation) {
1553-
emitCall(rewriter, printOp->getLoc(), [&] {
1557+
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
15541558
switch (punct) {
15551559
case PrintPunctuation::Close:
15561560
return LLVM::lookupOrCreatePrintCloseFn(parent);
@@ -1563,7 +1567,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
15631567
default:
15641568
llvm_unreachable("unexpected punctuation");
15651569
}
1566-
}());
1570+
}();
1571+
if (failed(op))
1572+
return failure();
1573+
emitCall(rewriter, printOp->getLoc(), op.value());
15671574
}
15681575

15691576
rewriter.eraseOp(printOp);
@@ -1588,7 +1595,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
15881595

15891596
// Make sure element type has runtime support.
15901597
PrintConversion conversion = PrintConversion::None;
1591-
Operation *printer;
1598+
FailureOr<Operation *> printer;
15921599
if (printType.isF32()) {
15931600
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
15941601
} else if (printType.isF64()) {
@@ -1631,6 +1638,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
16311638
} else {
16321639
return failure();
16331640
}
1641+
if (failed(printer))
1642+
return failure();
16341643

16351644
switch (conversion) {
16361645
case PrintConversion::ZeroExt64:
@@ -1648,7 +1657,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
16481657
case PrintConversion::None:
16491658
break;
16501659
}
1651-
emitCall(rewriter, loc, printer, value);
1660+
emitCall(rewriter, loc, printer.value(), value);
16521661
return success();
16531662
}
16541663

0 commit comments

Comments
 (0)