Skip to content

Commit 4d83429

Browse files
[mlir][GPU] Add gpu.assert op
1 parent 09c1daa commit 4d83429

File tree

7 files changed

+270
-63
lines changed

7 files changed

+270
-63
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
3838
class GPU_Op<string mnemonic, list<Trait> traits = []> :
3939
Op<GPU_Dialect, mnemonic, traits>;
4040

41+
def GPU_AssertOp : GPU_Op<"assert"> {
42+
let summary = "Device-side assertion";
43+
let description = [{
44+
The `gpu.assert` op is a device-side assertion. If the given `condition`
45+
is 0, the kernel execution is aborted, optionally with the given error
46+
message. This op is useful for debugging and verifying invariants.
47+
}];
48+
let arguments = (ins I1:$condition, OptionalAttr<StrAttr>:$message);
49+
let assemblyFormat = "$condition (`,` $message^)? attr-dict";
50+
}
51+
4152
def GPU_Dimension : I32EnumAttr<"Dimension",
4253
"a dimension, either 'x', 'y', or 'z'",
4354
[

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 60 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,59 @@
1919

2020
using namespace mlir;
2121

22+
LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
23+
Location loc, OpBuilder &b,
24+
StringRef name,
25+
LLVM::LLVMFunctionType type) {
26+
LLVM::LLVMFuncOp ret;
27+
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
28+
OpBuilder::InsertionGuard guard(b);
29+
b.setInsertionPointToStart(moduleOp.getBody());
30+
ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
31+
}
32+
return ret;
33+
}
34+
35+
static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
36+
StringRef prefix) {
37+
// Get a unique global name.
38+
unsigned stringNumber = 0;
39+
SmallString<16> stringConstName;
40+
do {
41+
stringConstName.clear();
42+
(prefix + Twine(stringNumber++)).toStringRef(stringConstName);
43+
} while (moduleOp.lookupSymbol(stringConstName));
44+
return stringConstName;
45+
}
46+
47+
LLVM::GlobalOp
48+
mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
49+
gpu::GPUModuleOp moduleOp, Type llvmI8,
50+
StringRef namePrefix, StringRef str,
51+
uint64_t alignment, unsigned addrSpace) {
52+
llvm::SmallString<20> nullTermStr(str);
53+
nullTermStr.push_back('\0'); // Null terminate for C
54+
auto globalType =
55+
LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
56+
StringAttr attr = b.getStringAttr(nullTermStr);
57+
58+
// Try to find existing global.
59+
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
60+
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
61+
globalOp.getValueAttr() == attr &&
62+
globalOp.getAlignment().value_or(0) == alignment &&
63+
globalOp.getAddrSpace() == addrSpace)
64+
return globalOp;
65+
66+
// Not found: create new global.
67+
OpBuilder::InsertionGuard guard(b);
68+
b.setInsertionPointToStart(moduleOp.getBody());
69+
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
70+
return b.create<LLVM::GlobalOp>(loc, globalType,
71+
/*isConstant=*/true, LLVM::Linkage::Internal,
72+
name, attr, alignment, addrSpace);
73+
}
74+
2275
LogicalResult
2376
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
2477
ConversionPatternRewriter &rewriter) const {
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
328381
return success();
329382
}
330383

331-
static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
332-
const char formatStringPrefix[] = "printfFormat_";
333-
// Get a unique global name.
334-
unsigned stringNumber = 0;
335-
SmallString<16> stringConstName;
336-
do {
337-
stringConstName.clear();
338-
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
339-
} while (moduleOp.lookupSymbol(stringConstName));
340-
return stringConstName;
341-
}
342-
343-
/// Create an global that contains the given format string. If a global with
344-
/// the same format string exists already in the module, return that global.
345-
static LLVM::GlobalOp getOrCreateFormatStringConstant(
346-
OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
347-
StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
348-
llvm::SmallString<20> formatString(str);
349-
formatString.push_back('\0'); // Null terminate for C
350-
auto globalType =
351-
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
352-
StringAttr attr = b.getStringAttr(formatString);
353-
354-
// Try to find existing global.
355-
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
356-
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
357-
globalOp.getValueAttr() == attr &&
358-
globalOp.getAlignment().value_or(0) == alignment &&
359-
globalOp.getAddrSpace() == addrSpace)
360-
return globalOp;
361-
362-
// Not found: create new global.
363-
OpBuilder::InsertionGuard guard(b);
364-
b.setInsertionPointToStart(moduleOp.getBody());
365-
SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
366-
return b.create<LLVM::GlobalOp>(loc, globalType,
367-
/*isConstant=*/true, LLVM::Linkage::Internal,
368-
name, attr, alignment, addrSpace);
369-
}
370-
371-
template <typename T>
372-
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
373-
ConversionPatternRewriter &rewriter,
374-
StringRef name,
375-
LLVM::LLVMFunctionType type) {
376-
LLVM::LLVMFuncOp ret;
377-
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
378-
ConversionPatternRewriter::InsertionGuard guard(rewriter);
379-
rewriter.setInsertionPointToStart(moduleOp.getBody());
380-
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
381-
LLVM::Linkage::External);
382-
}
383-
return ret;
384-
}
385-
386384
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
387385
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
388386
ConversionPatternRewriter &rewriter) const {
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
420418
Value printfDesc = printfBeginCall.getResult();
421419

422420
// Create the global op or find an existing one.
423-
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
424-
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
421+
LLVM::GlobalOp global = getOrCreateStringConstant(
422+
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
425423

426424
// Get a pointer to the format string's first element and pass it to printf()
427425
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
502500
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
503501

504502
// Create the global op or find an existing one.
505-
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
506-
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
507-
addressSpace);
503+
LLVM::GlobalOp global = getOrCreateStringConstant(
504+
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
505+
/*alignment=*/0, addressSpace);
508506

509507
// Get a pointer to the format string's first element
510508
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
546544
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
547545

548546
// Create the global op or find an existing one.
549-
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
550-
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
547+
LLVM::GlobalOp global = getOrCreateStringConstant(
548+
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
551549

552550
// Get a pointer to the format string's first element
553551
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,27 @@
1414

1515
namespace mlir {
1616

17+
//===----------------------------------------------------------------------===//
18+
// Helper Functions
19+
//===----------------------------------------------------------------------===//
20+
21+
/// Find or create an external function declaration in the given module.
22+
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
23+
OpBuilder &b, StringRef name,
24+
LLVM::LLVMFunctionType type);
25+
26+
/// Create a global that contains the given string. If a global with the same
27+
/// string already exists in the module, return that global.
28+
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
29+
gpu::GPUModuleOp moduleOp, Type llvmI8,
30+
StringRef namePrefix, StringRef str,
31+
uint64_t alignment = 0,
32+
unsigned addrSpace = 0);
33+
34+
//===----------------------------------------------------------------------===//
35+
// Lowering Patterns
36+
//===----------------------------------------------------------------------===//
37+
1738
/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
1839
/// create a 0-sized global array symbol similar as LLVM expects. It constructs
1940
/// a memref descriptor with these values and return it.

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
2626
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
2727
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
28+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2829
#include "mlir/Dialect/Func/IR/FuncOps.h"
2930
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3031
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -236,6 +237,105 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
236237
}
237238
};
238239

240+
/// Lowering of gpu.assert into a conditional __assertfail.
241+
struct GPUAssertOpToAssertfailLowering
242+
: public ConvertOpToLLVMPattern<gpu::AssertOp> {
243+
using ConvertOpToLLVMPattern<gpu::AssertOp>::ConvertOpToLLVMPattern;
244+
245+
LogicalResult
246+
matchAndRewrite(gpu::AssertOp assertOp, gpu::AssertOpAdaptor adaptor,
247+
ConversionPatternRewriter &rewriter) const override {
248+
MLIRContext *ctx = rewriter.getContext();
249+
Location loc = assertOp.getLoc();
250+
Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
251+
Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
252+
Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
253+
Type ptrType = LLVM::LLVMPointerType::get(ctx);
254+
Type voidType = LLVM::LLVMVoidType::get(ctx);
255+
256+
// Find or create __assertfail function declaration.
257+
auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
258+
auto assertfailType = LLVM::LLVMFunctionType::get(
259+
voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
260+
LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
261+
moduleOp, loc, rewriter, "__assertfail", assertfailType);
262+
assertfailDecl.setPassthroughAttr(
263+
ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
264+
265+
// Split blocks and insert conditional branch.
266+
// ^before:
267+
// ...
268+
// cf.cond_br %condition, ^after, ^assert
269+
// ^assert:
270+
// gpu.assert
271+
// cf.br ^after
272+
// ^after:
273+
// ...
274+
Block *beforeBlock = assertOp->getBlock();
275+
Block *assertBlock =
276+
rewriter.splitBlock(beforeBlock, assertOp->getIterator());
277+
Block *afterBlock =
278+
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
279+
rewriter.setInsertionPointToEnd(beforeBlock);
280+
rewriter.create<cf::CondBranchOp>(loc, adaptor.getCondition(), afterBlock,
281+
assertBlock);
282+
rewriter.setInsertionPointToEnd(assertBlock);
283+
rewriter.create<cf::BranchOp>(loc, afterBlock);
284+
285+
// Continue gpu.assert lowering.
286+
rewriter.setInsertionPoint(assertOp);
287+
288+
// Populate file name, file number and function name from the location of
289+
// the AssertOp.
290+
StringRef fileName = "(unknown)";
291+
StringRef funcName = "(unknown)";
292+
int32_t fileLine = 0;
293+
if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
294+
fileName = fileLineColLoc.getFilename().strref();
295+
fileLine = fileLineColLoc.getStartLine();
296+
} else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
297+
funcName = nameLoc.getName().strref();
298+
if (auto fileLineColLoc =
299+
dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
300+
fileName = fileLineColLoc.getFilename().strref();
301+
fileLine = fileLineColLoc.getStartLine();
302+
}
303+
}
304+
// Extract message string.
305+
StringRef message = "";
306+
if (assertOp.getMessage().has_value())
307+
message = *assertOp.getMessage();
308+
309+
// Create constants.
310+
auto getGlobal = [&](LLVM::GlobalOp global) {
311+
// Get a pointer to the format string's first element.
312+
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
313+
loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
314+
global.getSymNameAttr());
315+
Value start =
316+
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
317+
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
318+
return start;
319+
};
320+
Value assertMessage = getGlobal(getOrCreateStringConstant(
321+
rewriter, loc, moduleOp, i8Type, "assert_message_", message));
322+
Value assertFile = getGlobal(getOrCreateStringConstant(
323+
rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
324+
Value assertFunc = getGlobal(getOrCreateStringConstant(
325+
rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
326+
Value assertLine =
327+
rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
328+
Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
329+
330+
// Insert function call to __assertfail.
331+
SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
332+
assertFunc, c1};
333+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
334+
arguments);
335+
return success();
336+
}
337+
};
338+
239339
/// Import the GPU Ops to NVVM Patterns.
240340
#include "GPUToNVVM.cpp.inc"
241341

@@ -358,7 +458,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
358458
using gpu::index_lowering::IndexKind;
359459
using gpu::index_lowering::IntrType;
360460
populateWithGenerated(patterns);
361-
patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
461+
patterns.add<GPUPrintfOpToVPrintfLowering, GPUAssertOpToAssertfailLowering>(
462+
converter);
362463
patterns.add<
363464
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
364465
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,35 @@ gpu.module @test_module_50 {
969969
}
970970
}
971971

972+
// CHECK-LABEL: gpu.module @test_module_51
973+
// CHECK: llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32}
974+
// CHECK: llvm.mlir.global internal constant @[[file_name:.*]]("within split at {{.*}}gpu-to-nvvm.mlir:1 offset \00") {addr_space = 0 : i32}
975+
// CHECK: llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32}
976+
// CHECK: llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]}
977+
// CHECK: llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} {
978+
// CHECK: llvm.cond_br %[[cond]], ^[[assert_block:.*]], ^[[after_block:.*]]
979+
// CHECK: ^[[assert_block]]: // pred: ^bb0
980+
// CHECK: %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr
981+
// CHECK: %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8>
982+
// CHECK: %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr
983+
// CHECK: %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<74 x i8>
984+
// CHECK: %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr
985+
// CHECK: %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i8>
986+
// CHECK: %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32
987+
// CHECK: %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64
988+
// CHECK: llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
989+
// CHECK: llvm.br ^[[after_block]]
990+
// CHECK: ^[[after_block]]:
991+
// CHECK: llvm.return
992+
// CHECK: }
993+
994+
gpu.module @test_module_51 {
995+
gpu.func @test_assert(%arg0: i1) kernel {
996+
gpu.assert %arg0, "assert message"
997+
gpu.return
998+
}
999+
}
1000+
9721001
module attributes {transform.with_named_sequence} {
9731002
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
9741003
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,12 @@ func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4
500500
}
501501
return %2 : vector<4xi32>
502502
}
503+
504+
// CHECK-LABEL: func @test_assert(
505+
func.func @test_assert(%cond : i1) {
506+
// CHECK: gpu.assert %{{.*}}, "message"
507+
gpu.assert %cond, "message"
508+
// CHECK: gpu.assert %{{.*}}
509+
gpu.assert %cond
510+
return
511+
}

0 commit comments

Comments
 (0)