Skip to content

Commit 86d9040

Browse files
matthias-springer, paulhuggett
authored and committed
[mlir][GPU] Add NVVM-specific cf.assert lowering (llvm#120431)
This commit adds an NVIDIA-specific lowering of `cf.assert` to `__assertfail`. Note: `getUniqueFormatGlobalName`, `getOrCreateFormatStringConstant` and `getOrDefineFunction` are moved to `GPUOpsLowering.h`, so that they can be reused.
1 parent 5436414 commit 86d9040

File tree

11 files changed

+258
-64
lines changed

11 files changed

+258
-64
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3928,6 +3928,7 @@ class FIRToLLVMLowering
39283928
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);
39293929
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
39303930
pattern);
3931+
mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern);
39313932
// Math operations that have not been converted yet must be converted
39323933
// to Libm.
39333934
if (!isAMDGCN)

mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
220220
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
221221
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
222222
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
223+
cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
223224
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
224225

225226
// The only remaining operation to lower from the `toy` dialect, is the

mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ namespace cf {
2929
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
3030
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
3131
/// references have to remain alive during the entire pattern lifetime.
32+
///
33+
/// Note: This function does not populate the default cf.assert lowering. That
34+
/// is because some platforms have a custom cf.assert lowering. The default
35+
/// lowering can be populated with `populateAssertToLLVMConversionPattern`.
3236
void populateControlFlowToLLVMConversionPatterns(
3337
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
3438

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,6 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
215215
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
216216
// clang-format off
217217
patterns.add<
218-
AssertOpLowering,
219218
BranchOpLowering,
220219
CondBranchOpLowering,
221220
SwitchOpLowering>(converter);
@@ -258,6 +257,7 @@ struct ConvertControlFlowToLLVM
258257
LLVMTypeConverter converter(ctx, options);
259258
RewritePatternSet patterns(ctx);
260259
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
260+
mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
261261

262262
if (failed(applyPartialConversion(getOperation(), target,
263263
std::move(patterns))))
@@ -286,6 +286,7 @@ struct ControlFlowToLLVMDialectInterface
286286
RewritePatternSet &patterns) const final {
287287
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
288288
patterns);
289+
mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
289290
}
290291
};
291292
} // namespace

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 60 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,59 @@
1919

2020
using namespace mlir;
2121

22+
LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
23+
Location loc, OpBuilder &b,
24+
StringRef name,
25+
LLVM::LLVMFunctionType type) {
26+
LLVM::LLVMFuncOp ret;
27+
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
28+
OpBuilder::InsertionGuard guard(b);
29+
b.setInsertionPointToStart(moduleOp.getBody());
30+
ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
31+
}
32+
return ret;
33+
}
34+
35+
static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
36+
StringRef prefix) {
37+
// Get a unique global name.
38+
unsigned stringNumber = 0;
39+
SmallString<16> stringConstName;
40+
do {
41+
stringConstName.clear();
42+
(prefix + Twine(stringNumber++)).toStringRef(stringConstName);
43+
} while (moduleOp.lookupSymbol(stringConstName));
44+
return stringConstName;
45+
}
46+
47+
LLVM::GlobalOp
48+
mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
49+
gpu::GPUModuleOp moduleOp, Type llvmI8,
50+
StringRef namePrefix, StringRef str,
51+
uint64_t alignment, unsigned addrSpace) {
52+
llvm::SmallString<20> nullTermStr(str);
53+
nullTermStr.push_back('\0'); // Null terminate for C
54+
auto globalType =
55+
LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
56+
StringAttr attr = b.getStringAttr(nullTermStr);
57+
58+
// Try to find existing global.
59+
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
60+
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
61+
globalOp.getValueAttr() == attr &&
62+
globalOp.getAlignment().value_or(0) == alignment &&
63+
globalOp.getAddrSpace() == addrSpace)
64+
return globalOp;
65+
66+
// Not found: create new global.
67+
OpBuilder::InsertionGuard guard(b);
68+
b.setInsertionPointToStart(moduleOp.getBody());
69+
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
70+
return b.create<LLVM::GlobalOp>(loc, globalType,
71+
/*isConstant=*/true, LLVM::Linkage::Internal,
72+
name, attr, alignment, addrSpace);
73+
}
74+
2275
LogicalResult
2376
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
2477
ConversionPatternRewriter &rewriter) const {
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
328381
return success();
329382
}
330383

331-
static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
332-
const char formatStringPrefix[] = "printfFormat_";
333-
// Get a unique global name.
334-
unsigned stringNumber = 0;
335-
SmallString<16> stringConstName;
336-
do {
337-
stringConstName.clear();
338-
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
339-
} while (moduleOp.lookupSymbol(stringConstName));
340-
return stringConstName;
341-
}
342-
343-
/// Create an global that contains the given format string. If a global with
344-
/// the same format string exists already in the module, return that global.
345-
static LLVM::GlobalOp getOrCreateFormatStringConstant(
346-
OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
347-
StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
348-
llvm::SmallString<20> formatString(str);
349-
formatString.push_back('\0'); // Null terminate for C
350-
auto globalType =
351-
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
352-
StringAttr attr = b.getStringAttr(formatString);
353-
354-
// Try to find existing global.
355-
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
356-
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
357-
globalOp.getValueAttr() == attr &&
358-
globalOp.getAlignment().value_or(0) == alignment &&
359-
globalOp.getAddrSpace() == addrSpace)
360-
return globalOp;
361-
362-
// Not found: create new global.
363-
OpBuilder::InsertionGuard guard(b);
364-
b.setInsertionPointToStart(moduleOp.getBody());
365-
SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
366-
return b.create<LLVM::GlobalOp>(loc, globalType,
367-
/*isConstant=*/true, LLVM::Linkage::Internal,
368-
name, attr, alignment, addrSpace);
369-
}
370-
371-
template <typename T>
372-
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
373-
ConversionPatternRewriter &rewriter,
374-
StringRef name,
375-
LLVM::LLVMFunctionType type) {
376-
LLVM::LLVMFuncOp ret;
377-
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
378-
ConversionPatternRewriter::InsertionGuard guard(rewriter);
379-
rewriter.setInsertionPointToStart(moduleOp.getBody());
380-
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
381-
LLVM::Linkage::External);
382-
}
383-
return ret;
384-
}
385-
386384
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
387385
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
388386
ConversionPatternRewriter &rewriter) const {
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
420418
Value printfDesc = printfBeginCall.getResult();
421419

422420
// Create the global op or find an existing one.
423-
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
424-
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
421+
LLVM::GlobalOp global = getOrCreateStringConstant(
422+
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
425423

426424
// Get a pointer to the format string's first element and pass it to printf()
427425
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
502500
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
503501

504502
// Create the global op or find an existing one.
505-
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
506-
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
507-
addressSpace);
503+
LLVM::GlobalOp global = getOrCreateStringConstant(
504+
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
505+
/*alignment=*/0, addressSpace);
508506

509507
// Get a pointer to the format string's first element
510508
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
546544
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
547545

548546
// Create the global op or find an existing one.
549-
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
550-
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
547+
LLVM::GlobalOp global = getOrCreateStringConstant(
548+
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
551549

552550
// Get a pointer to the format string's first element
553551
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,27 @@
1414

1515
namespace mlir {
1616

17+
//===----------------------------------------------------------------------===//
18+
// Helper Functions
19+
//===----------------------------------------------------------------------===//
20+
21+
/// Find or create an external function declaration in the given module.
22+
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
23+
OpBuilder &b, StringRef name,
24+
LLVM::LLVMFunctionType type);
25+
26+
/// Create a global that contains the given string. If a global with the same
27+
/// string already exists in the module, return that global.
28+
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
29+
gpu::GPUModuleOp moduleOp, Type llvmI8,
30+
StringRef namePrefix, StringRef str,
31+
uint64_t alignment = 0,
32+
unsigned addrSpace = 0);
33+
34+
//===----------------------------------------------------------------------===//
35+
// Lowering Patterns
36+
//===----------------------------------------------------------------------===//
37+
1738
/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
1839
/// creates a 0-sized global array symbol, as LLVM expects. It constructs
1940
/// a memref descriptor with these values and returns it.

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
2626
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
2727
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
28+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2829
#include "mlir/Dialect/Func/IR/FuncOps.h"
2930
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3031
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -236,6 +237,103 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
236237
}
237238
};
238239

240+
/// Lowering of cf.assert into a conditional __assertfail.
241+
struct AssertOpToAssertfailLowering
242+
: public ConvertOpToLLVMPattern<cf::AssertOp> {
243+
using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
244+
245+
LogicalResult
246+
matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
247+
ConversionPatternRewriter &rewriter) const override {
248+
MLIRContext *ctx = rewriter.getContext();
249+
Location loc = assertOp.getLoc();
250+
Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
251+
Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
252+
Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
253+
Type ptrType = LLVM::LLVMPointerType::get(ctx);
254+
Type voidType = LLVM::LLVMVoidType::get(ctx);
255+
256+
// Find or create __assertfail function declaration.
257+
auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
258+
auto assertfailType = LLVM::LLVMFunctionType::get(
259+
voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
260+
LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
261+
moduleOp, loc, rewriter, "__assertfail", assertfailType);
262+
assertfailDecl.setPassthroughAttr(
263+
ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
264+
265+
// Split blocks and insert conditional branch.
266+
// ^before:
267+
// ...
268+
// cf.cond_br %condition, ^after, ^assert
269+
// ^assert:
270+
// cf.assert
271+
// cf.br ^after
272+
// ^after:
273+
// ...
274+
Block *beforeBlock = assertOp->getBlock();
275+
Block *assertBlock =
276+
rewriter.splitBlock(beforeBlock, assertOp->getIterator());
277+
Block *afterBlock =
278+
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
279+
rewriter.setInsertionPointToEnd(beforeBlock);
280+
rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
281+
assertBlock);
282+
rewriter.setInsertionPointToEnd(assertBlock);
283+
rewriter.create<cf::BranchOp>(loc, afterBlock);
284+
285+
// Continue cf.assert lowering.
286+
rewriter.setInsertionPoint(assertOp);
287+
288+
// Populate file name, line number and function name from the location of
289+
// the AssertOp.
290+
StringRef fileName = "(unknown)";
291+
StringRef funcName = "(unknown)";
292+
int32_t fileLine = 0;
293+
while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
294+
loc = callSiteLoc.getCallee();
295+
if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
296+
fileName = fileLineColLoc.getFilename().strref();
297+
fileLine = fileLineColLoc.getStartLine();
298+
} else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
299+
funcName = nameLoc.getName().strref();
300+
if (auto fileLineColLoc =
301+
dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
302+
fileName = fileLineColLoc.getFilename().strref();
303+
fileLine = fileLineColLoc.getStartLine();
304+
}
305+
}
306+
307+
// Create constants.
308+
auto getGlobal = [&](LLVM::GlobalOp global) {
309+
// Get a pointer to the string constant's first element.
310+
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
311+
loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
312+
global.getSymNameAttr());
313+
Value start =
314+
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
315+
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
316+
return start;
317+
};
318+
Value assertMessage = getGlobal(getOrCreateStringConstant(
319+
rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
320+
Value assertFile = getGlobal(getOrCreateStringConstant(
321+
rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
322+
Value assertFunc = getGlobal(getOrCreateStringConstant(
323+
rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
324+
Value assertLine =
325+
rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
326+
Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
327+
328+
// Insert function call to __assertfail.
329+
SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
330+
assertFunc, c1};
331+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
332+
arguments);
333+
return success();
334+
}
335+
};
336+
239337
/// Import the GPU Ops to NVVM Patterns.
240338
#include "GPUToNVVM.cpp.inc"
241339

@@ -358,7 +456,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
358456
using gpu::index_lowering::IndexKind;
359457
using gpu::index_lowering::IntrType;
360458
populateWithGenerated(patterns);
361-
patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
459+
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
460+
converter);
362461
patterns.add<
363462
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
364463
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ struct LowerGpuOpsToROCDLOpsPass
296296
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
297297
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
298298
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
299+
cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
299300
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
300301
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
301302
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);

mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
304304
LLVMTypeConverter converter(&getContext());
305305
arith::populateArithToLLVMConversionPatterns(converter, patterns);
306306
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
307+
cf::populateAssertToLLVMConversionPattern(converter, patterns);
307308
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
308309
populateFuncToLLVMConversionPatterns(converter, patterns);
309310
populateOpenMPToLLVMConversionPatterns(converter, patterns);

0 commit comments

Comments
 (0)