[mlir][GPU] Add NVVM-specific `cf.assert` lowering #120431
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
Changes
This commit adds a new operation to the GPU dialect: `gpu.assert`.
This commit also includes an NVIDIA-specific lowering to `__assertfail`.
Note: Patch is 20.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120431.diff
8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 42a017db300af6..793d663c0322a5 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -38,6 +38,17 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class GPU_Op<string mnemonic, list<Trait> traits = []> :
Op<GPU_Dialect, mnemonic, traits>;
+def GPU_AssertOp : GPU_Op<"assert"> {
+ let summary = "Device-side assertion";
+ let description = [{
+ The `gpu.assert` op is a device-side assertion. If the given `condition`
+ is 0, the kernel execution is aborted, optionally with the given error
+ message. This op is useful for debugging and verifying invariants.
+ }];
+ let arguments = (ins I1:$condition, OptionalAttr<StrAttr>:$message);
+ let assemblyFormat = "$condition (`,` $message^)? attr-dict";
+}
+
def GPU_Dimension : I32EnumAttr<"Dimension",
"a dimension, either 'x', 'y', or 'z'",
[
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index b3c3fd4956d0bb..544fc57949e24d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -19,6 +19,59 @@
using namespace mlir;
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
+ Location loc, OpBuilder &b,
+ StringRef name,
+ LLVM::LLVMFunctionType type) {
+ LLVM::LLVMFuncOp ret;
+ if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleOp.getBody());
+ ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+ }
+ return ret;
+}
+
+static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+ StringRef prefix) {
+ // Get a unique global name.
+ unsigned stringNumber = 0;
+ SmallString<16> stringConstName;
+ do {
+ stringConstName.clear();
+ (prefix + Twine(stringNumber++)).toStringRef(stringConstName);
+ } while (moduleOp.lookupSymbol(stringConstName));
+ return stringConstName;
+}
+
+LLVM::GlobalOp
+mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+ gpu::GPUModuleOp moduleOp, Type llvmI8,
+ StringRef namePrefix, StringRef str,
+ uint64_t alignment, unsigned addrSpace) {
+ llvm::SmallString<20> nullTermStr(str);
+ nullTermStr.push_back('\0'); // Null terminate for C
+ auto globalType =
+ LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
+ StringAttr attr = b.getStringAttr(nullTermStr);
+
+ // Try to find existing global.
+ for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+ if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
+ globalOp.getValueAttr() == attr &&
+ globalOp.getAlignment().value_or(0) == alignment &&
+ globalOp.getAddrSpace() == addrSpace)
+ return globalOp;
+
+ // Not found: create new global.
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleOp.getBody());
+ SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
+ return b.create<LLVM::GlobalOp>(loc, globalType,
+ /*isConstant=*/true, LLVM::Linkage::Internal,
+ name, attr, alignment, addrSpace);
+}
+
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
return success();
}
-static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
- const char formatStringPrefix[] = "printfFormat_";
- // Get a unique global name.
- unsigned stringNumber = 0;
- SmallString<16> stringConstName;
- do {
- stringConstName.clear();
- (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
- } while (moduleOp.lookupSymbol(stringConstName));
- return stringConstName;
-}
-
-/// Create an global that contains the given format string. If a global with
-/// the same format string exists already in the module, return that global.
-static LLVM::GlobalOp getOrCreateFormatStringConstant(
- OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
- StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
- llvm::SmallString<20> formatString(str);
- formatString.push_back('\0'); // Null terminate for C
- auto globalType =
- LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
- StringAttr attr = b.getStringAttr(formatString);
-
- // Try to find existing global.
- for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
- if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
- globalOp.getValueAttr() == attr &&
- globalOp.getAlignment().value_or(0) == alignment &&
- globalOp.getAddrSpace() == addrSpace)
- return globalOp;
-
- // Not found: create new global.
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
- SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
- return b.create<LLVM::GlobalOp>(loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal,
- name, attr, alignment, addrSpace);
-}
-
-template <typename T>
-static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
- ConversionPatternRewriter &rewriter,
- StringRef name,
- LLVM::LLVMFunctionType type) {
- LLVM::LLVMFuncOp ret;
- if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
- LLVM::Linkage::External);
- }
- return ret;
-}
-
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
Value printfDesc = printfBeginCall.getResult();
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
- addressSpace);
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
+ /*alignment=*/0, addressSpace);
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 444a07a93ca36e..e73a74845d2b66 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
namespace mlir {
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+/// Find or create an external function declaration in the given module.
+LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
+ OpBuilder &b, StringRef name,
+ LLVM::LLVMFunctionType type);
+
+/// Create a global that contains the given string. If a global with the same
+/// string already exists in the module, return that global.
+LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
+ gpu::GPUModuleOp moduleOp, Type llvmI8,
+ StringRef namePrefix, StringRef str,
+ uint64_t alignment = 0,
+ unsigned addrSpace = 0);
+
+//===----------------------------------------------------------------------===//
+// Lowering Patterns
+//===----------------------------------------------------------------------===//
+
/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
/// create a 0-sized global array symbol similar as LLVM expects. It constructs
/// a memref descriptor with these values and return it.
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index b343cf71e3a2e7..f95f6ea0ff41c4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -25,6 +25,7 @@
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -236,6 +237,105 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
}
};
+/// Lowering of gpu.assert into a conditional __assertfail.
+struct GPUAssertOpToAssertfailLowering
+ : public ConvertOpToLLVMPattern<gpu::AssertOp> {
+ using ConvertOpToLLVMPattern<gpu::AssertOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::AssertOp assertOp, gpu::AssertOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MLIRContext *ctx = rewriter.getContext();
+ Location loc = assertOp.getLoc();
+ Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
+ Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
+ Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
+ Type ptrType = LLVM::LLVMPointerType::get(ctx);
+ Type voidType = LLVM::LLVMVoidType::get(ctx);
+
+ // Find or create __assertfail function declaration.
+ auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
+ auto assertfailType = LLVM::LLVMFunctionType::get(
+ voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
+ LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
+ moduleOp, loc, rewriter, "__assertfail", assertfailType);
+ assertfailDecl.setPassthroughAttr(
+ ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
+
+ // Split blocks and insert conditional branch.
+ // ^before:
+ // ...
+ // cf.cond_br %condition, ^after, ^assert
+ // ^assert:
+ // gpu.assert
+ // cf.br ^after
+ // ^after:
+ // ...
+ Block *beforeBlock = assertOp->getBlock();
+ Block *assertBlock =
+ rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+ Block *afterBlock =
+ rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+ rewriter.setInsertionPointToEnd(beforeBlock);
+ rewriter.create<cf::CondBranchOp>(loc, adaptor.getCondition(), afterBlock,
+ assertBlock);
+ rewriter.setInsertionPointToEnd(assertBlock);
+ rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+ // Continue gpu.assert lowering.
+ rewriter.setInsertionPoint(assertOp);
+
+ // Populate file name, line number and function name from the location of
+ // the AssertOp.
+ StringRef fileName = "(unknown)";
+ StringRef funcName = "(unknown)";
+ int32_t fileLine = 0;
+ if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
+ fileName = fileLineColLoc.getFilename().strref();
+ fileLine = fileLineColLoc.getStartLine();
+ } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+ funcName = nameLoc.getName().strref();
+ if (auto fileLineColLoc =
+ dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
+ fileName = fileLineColLoc.getFilename().strref();
+ fileLine = fileLineColLoc.getStartLine();
+ }
+ }
+ // Extract message string.
+ StringRef message = "";
+ if (assertOp.getMessage().has_value())
+ message = *assertOp.getMessage();
+
+ // Create constants.
+ auto getGlobal = [&](LLVM::GlobalOp global) {
+ // Get a pointer to the format string's first element.
+ Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
+ loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
+ global.getSymNameAttr());
+ Value start =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ return start;
+ };
+ Value assertMessage = getGlobal(getOrCreateStringConstant(
+ rewriter, loc, moduleOp, i8Type, "assert_message_", message));
+ Value assertFile = getGlobal(getOrCreateStringConstant(
+ rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
+ Value assertFunc = getGlobal(getOrCreateStringConstant(
+ rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
+ Value assertLine =
+ rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
+ Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
+
+ // Insert function call to __assertfail.
+ SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
+ assertFunc, c1};
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
+ arguments);
+ return success();
+ }
+};
+
/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"
@@ -358,7 +458,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);
- patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
+ patterns.add<GPUPrintfOpToVPrintfLowering, GPUAssertOpToAssertfailLowering>(
+ converter);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
index fb440756e0c1d5..20d7372eef85d5 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
@@ -73,7 +73,6 @@ void buildCommonPassPipeline(
//===----------------------------------------------------------------------===//
void buildGpuPassPipeline(OpPassManager &pm,
const mlir::gpu::GPUToNVVMPipelineOptions &options) {
- pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
ConvertGpuOpsToNVVMOpsOptions opt;
opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv;
opt.indexBitwidth = options.indexBitWidth;
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 748dfe8c68fc7e..7e50734350be2a 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -969,6 +969,35 @@ gpu.module @test_module_50 {
}
}
+// CHECK-LABEL: gpu.module @test_module_51
+// CHECK: llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32}
+// CHECK: llvm.mlir.global internal constant @[[file_name:.*]]("within split at mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir:1 offset \00") {addr_space = 0 : i32}
+// CHECK: llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32}
+// CHECK: llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]}
+// CHECK: llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} {
+// CHECK: llvm.cond_br %[[cond]], ^[[assert_block:.*]], ^[[after_block:.*]]
+// CHECK: ^[[assert_block]]: // pred: ^bb0
+// CHECK: %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr
+// CHECK: %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8>
+// CHECK: %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr
+// CHECK: %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<74 x i8>
+// CHECK: %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr
+// CHECK: %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i8>
+// CHECK: %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32
+// CHECK: %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
+// CHECK: llvm.br ^[[after_block]]
+// CHECK: ^[[after_block]]:
+// CHECK: llvm.return
+// CHECK: }
+
+gpu.module @test_module_51 {
+ gpu.func @test_assert(%arg0: i1) kernel {
+ gpu.assert %arg0, "assert message"
+ gpu.return
+ }
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c0ff2044b76c40..d74ff9482b911d 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -500,3 +500,12 @@ func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4
}
return %2 : vector<4xi32>
}
+
+// CHECK-LABEL: func @test_assert(
+func.func @test_assert(%cond : i1) {
+ // CHECK: gpu.assert %{{.*}}, "message"
+ gpu.assert %cond, "message"
+ // CHECK: gpu.assert %{{.*}}
+ gpu.assert %cond
+ return
+}
diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir
new file mode 100644
index 00000000000000..40745f26af16cf
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/assert.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void 2>&1 \
+// RUN: | FileCheck %s
+
+// CHECK-DAG: thread 0: print after passing assertion
+// CHECK-DAG: thread 1: print after passing assertion
+// CHECK-DAG: mlir/test/Integration/GPU/CUDA/asser...
[truncated]
Can we just use `cf.assert`?
This is a good idea. But I first have to move a few things around because there is an existing pattern in
Title changed: … `gpu.assert` op → … `cf.assert` lowering
Do not run `cf-to-llvm` as part of `func-to-llvm`.
This commit fixes #70982.
This commit changes how `func.func` ops are lowered to LLVM. Previously, the signature of the entire region (i.e., the entry block and all other blocks in the `func.func` op) was converted as part of the `func.func` lowering pattern. Now, only the entry block is converted. The remaining block signatures are converted together with `cf.br` and `cf.cond_br` as part of `cf-to-llvm`. All unstructured control flow is now converted as part of a single pass (`cf-to-llvm`). `func-to-llvm` no longer deals with unstructured control flow.
Also add more test cases for control flow dialect ops.
Note: This PR is in preparation for #120431, which adds an additional GPU-specific lowering for `cf.assert`. This was a problem because `cf.assert` used to be converted as part of `func-to-llvm`.
Note for LLVM integration: If you see failures, add `-convert-cf-to-llvm` to your pass pipeline.
Thanks for implementing this. It's definitely nice to have assert support.
This commit adds an NVIDIA-specific lowering of `cf.assert` to `__assertfail`. Note: `getUniqueFormatGlobalName`, `getOrCreateFormatStringConstant` and `getOrDefineFunction` are moved to `GPUOpsLowering.h`, so that they can be reused.
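To make the reuse of the string-constant helpers more concrete, here is a minimal standalone sketch in plain C++ with no MLIR dependency. It models only the find-or-create logic visible in the diff and the argument order of the `__assertfail` call. `ModuleModel`, `StringGlobal`, `getOrCreateString`, `AssertFailArgs` and `buildAssertFailArgs` are hypothetical names invented for illustration. The sketch also assumes that lookup matches on string contents only, as the lookup loop in the diff appears to do; it ignores type, alignment and address space.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Simplified stand-in for a string global in a GPU module (not the MLIR API).
struct StringGlobal {
  std::string name;  // symbol name, e.g. "assert_message_0"
  std::string value; // contents, including the trailing '\0'
};

// Hypothetical model of a module that only tracks its string globals.
struct ModuleModel {
  std::vector<StringGlobal> globals;

  bool hasSymbol(const std::string &name) const {
    for (const StringGlobal &g : globals)
      if (g.name == name)
        return true;
    return false;
  }
};

// Append an increasing counter to `prefix` until the name is unused.
std::string uniqueSymbolName(const ModuleModel &m, const std::string &prefix) {
  assert(!prefix.empty() && "expected a non-empty prefix");
  unsigned n = 0;
  std::string name;
  do {
    name = prefix + std::to_string(n++);
  } while (m.hasSymbol(name));
  return name;
}

// Return the symbol of a global holding `str`. A new global is created only
// if no existing global has the same contents; the prefix is used solely for
// naming new globals, not for lookup.
std::string getOrCreateString(ModuleModel &m, const std::string &prefix,
                              const std::string &str) {
  std::string value = str;
  value.push_back('\0'); // null-terminate, as a C callee would expect
  for (const StringGlobal &g : m.globals)
    if (g.value == value)
      return g.name;
  std::string name = uniqueSymbolName(m, prefix);
  m.globals.push_back({name, value});
  return name;
}

// Arguments in the order the diff passes them to __assertfail:
// message, file, line, function, and a constant 1.
struct AssertFailArgs {
  std::string messageSym;
  std::string fileSym;
  int32_t line;
  std::string funcSym;
  uint64_t charSize;
};

AssertFailArgs buildAssertFailArgs(ModuleModel &m, const std::string &message,
                                   const std::string &file, int32_t line,
                                   const std::string &func) {
  AssertFailArgs args;
  args.messageSym = getOrCreateString(m, "assert_message_", message);
  args.fileSym = getOrCreateString(m, "assert_file_", file);
  args.funcSym = getOrCreateString(m, "assert_func_", func);
  args.line = line;
  args.charSize = 1;
  return args;
}
```

Under these assumptions, when both the file name and the function name fall back to `"(unknown)"`, the second lookup reuses the global created for the first, so only two globals are emitted for that assertion.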