Skip to content

[mlir][GPU] Add NVVM-specific cf.assert lowering #120431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3928,6 +3928,7 @@ class FIRToLLVMLowering
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
pattern);
mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern);
// Math operations that have not been converted yet must be converted
// to Libm.
if (!isAMDGCN)
Expand Down
1 change: 1 addition & 0 deletions mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
populateFuncToLLVMConversionPatterns(typeConverter, patterns);

// The only remaining operation to lower from the `toy` dialect, is the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ namespace cf {
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
/// references have to remain alive during the entire pattern lifetime.
///
/// Note: This function does not populate the default cf.assert lowering. That
/// is because some platforms have a custom cf.assert lowering. The default
/// lowering can be populated with `populateAssertToLLVMConversionPattern`.
void populateControlFlowToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
AssertOpLowering,
BranchOpLowering,
CondBranchOpLowering,
SwitchOpLowering>(converter);
Expand Down Expand Up @@ -258,6 +257,7 @@ struct ConvertControlFlowToLLVM
LLVMTypeConverter converter(ctx, options);
RewritePatternSet patterns(ctx);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
Expand Down Expand Up @@ -286,6 +286,7 @@ struct ControlFlowToLLVMDialectInterface
RewritePatternSet &patterns) const final {
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
}
};
} // namespace
Expand Down
122 changes: 60 additions & 62 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,59 @@

using namespace mlir;

LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
Location loc, OpBuilder &b,
StringRef name,
LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
}
return ret;
}

static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
StringRef prefix) {
// Get a unique global name.
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(prefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
return stringConstName;
}

LLVM::GlobalOp
mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
gpu::GPUModuleOp moduleOp, Type llvmI8,
StringRef namePrefix, StringRef str,
uint64_t alignment, unsigned addrSpace) {
llvm::SmallString<20> nullTermStr(str);
nullTermStr.push_back('\0'); // Null terminate for C
auto globalType =
LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
StringAttr attr = b.getStringAttr(nullTermStr);

// Try to find existing global.
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
globalOp.getValueAttr() == attr &&
globalOp.getAlignment().value_or(0) == alignment &&
globalOp.getAddrSpace() == addrSpace)
return globalOp;

// Not found: create new global.
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
return b.create<LLVM::GlobalOp>(loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
name, attr, alignment, addrSpace);
}

LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
return success();
}

static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
const char formatStringPrefix[] = "printfFormat_";
// Get a unique global name.
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
return stringConstName;
}

/// Create an global that contains the given format string. If a global with
/// the same format string exists already in the module, return that global.
static LLVM::GlobalOp getOrCreateFormatStringConstant(
OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
llvm::SmallString<20> formatString(str);
formatString.push_back('\0'); // Null terminate for C
auto globalType =
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
StringAttr attr = b.getStringAttr(formatString);

// Try to find existing global.
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
globalOp.getValueAttr() == attr &&
globalOp.getAlignment().value_or(0) == alignment &&
globalOp.getAddrSpace() == addrSpace)
return globalOp;

// Not found: create new global.
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
return b.create<LLVM::GlobalOp>(loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
name, attr, alignment, addrSpace);
}

template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
ConversionPatternRewriter &rewriter,
StringRef name,
LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
LLVM::Linkage::External);
}
return ret;
}

LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
Value printfDesc = printfBeginCall.getResult();

// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
LLVM::GlobalOp global = getOrCreateStringConstant(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());

// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
Expand Down Expand Up @@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);

// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
addressSpace);
LLVM::GlobalOp global = getOrCreateStringConstant(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
/*alignment=*/0, addressSpace);

// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
Expand Down Expand Up @@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);

// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
LLVM::GlobalOp global = getOrCreateStringConstant(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());

// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,27 @@

namespace mlir {

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

/// Find or create an external function declaration in the given module.
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
OpBuilder &b, StringRef name,
LLVM::LLVMFunctionType type);

/// Create a global that contains the given string. If a global with the same
/// string already exists in the module, return that global.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
gpu::GPUModuleOp moduleOp, Type llvmI8,
StringRef namePrefix, StringRef str,
uint64_t alignment = 0,
unsigned addrSpace = 0);

//===----------------------------------------------------------------------===//
// Lowering Patterns
//===----------------------------------------------------------------------===//

/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
/// create a 0-sized global array symbol similar as LLVM expects. It constructs
/// a memref descriptor with these values and return it.
Expand Down
101 changes: 100 additions & 1 deletion mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
Expand Down Expand Up @@ -236,6 +237,103 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
}
};

/// Lowering of cf.assert into a conditional __assertfail.
struct AssertOpToAssertfailLowering
: public ConvertOpToLLVMPattern<cf::AssertOp> {
using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = rewriter.getContext();
Location loc = assertOp.getLoc();
Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
Type ptrType = LLVM::LLVMPointerType::get(ctx);
Type voidType = LLVM::LLVMVoidType::get(ctx);

// Find or create __assertfail function declaration.
auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
auto assertfailType = LLVM::LLVMFunctionType::get(
voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
moduleOp, loc, rewriter, "__assertfail", assertfailType);
assertfailDecl.setPassthroughAttr(
ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

// Split blocks and insert conditional branch.
// ^before:
// ...
// cf.cond_br %condition, ^after, ^assert
// ^assert:
// cf.assert
// cf.br ^after
// ^after:
// ...
Block *beforeBlock = assertOp->getBlock();
Block *assertBlock =
rewriter.splitBlock(beforeBlock, assertOp->getIterator());
Block *afterBlock =
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
rewriter.setInsertionPointToEnd(beforeBlock);
rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
assertBlock);
rewriter.setInsertionPointToEnd(assertBlock);
rewriter.create<cf::BranchOp>(loc, afterBlock);

// Continue cf.assert lowering.
rewriter.setInsertionPoint(assertOp);

// Populate file name, file number and function name from the location of
// the AssertOp.
StringRef fileName = "(unknown)";
StringRef funcName = "(unknown)";
int32_t fileLine = 0;
while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
loc = callSiteLoc.getCallee();
if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
fileName = fileLineColLoc.getFilename().strref();
fileLine = fileLineColLoc.getStartLine();
} else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
funcName = nameLoc.getName().strref();
if (auto fileLineColLoc =
dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
fileName = fileLineColLoc.getFilename().strref();
fileLine = fileLineColLoc.getStartLine();
}
}

// Create constants.
auto getGlobal = [&](LLVM::GlobalOp global) {
// Get a pointer to the format string's first element.
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
global.getSymNameAttr());
Value start =
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
return start;
};
Value assertMessage = getGlobal(getOrCreateStringConstant(
rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
Value assertFile = getGlobal(getOrCreateStringConstant(
rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
Value assertFunc = getGlobal(getOrCreateStringConstant(
rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
Value assertLine =
rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);

// Insert function call to __assertfail.
SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
assertFunc, c1};
rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
arguments);
return success();
}
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

Expand Down Expand Up @@ -358,7 +456,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);
patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
converter);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ struct LowerGpuOpsToROCDLOpsPass
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
arith::populateArithToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
cf::populateAssertToLLVMConversionPattern(converter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
Expand Down
Loading
Loading