[mlir][GPU] Add NVVM-specific cf.assert lowering #120431

Merged: 1 commit merged into main from users/matthias-springer/gpu_assert on Jan 6, 2025

@matthias-springer (Member) commented Dec 18, 2024

This commit adds an NVIDIA-specific lowering of cf.assert to __assertfail.

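Note: getOrCreateFormatStringConstant (renamed to getOrCreateStringConstant, which now takes a name prefix) and getOrDefineFunction are declared in GPUOpsLowering.h so that they can be reused. getUniqueFormatGlobalName is generalized into getUniqueSymbolName in GPUOpsLowering.cpp.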
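
As a rough illustration of the lowering described above, the following plain C++ sketch mimics the control flow the pattern emits: branch on the condition and call the failure handler only when the condition is false. It is a host-side model, not the MLIR pattern itself. `AssertFailure`, `mockAssertfail` and `loweredAssert` are hypothetical names introduced here, and the mock returns normally, whereas the real `__assertfail` is declared `noreturn` in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical record of one failed assertion. The fields mirror the five
// arguments the patch passes to __assertfail: message, file, line, function,
// and a trailing 64-bit constant.
struct AssertFailure {
  std::string message;
  std::string file;
  std::uint32_t line;
  std::string function;
  std::size_t lastArg;
};

// Host-side stand-in for the device function __assertfail (assumption: this
// mock only logs; the real function aborts kernel execution).
inline void mockAssertfail(std::vector<AssertFailure> &log, const char *message,
                           const char *file, std::uint32_t line,
                           const char *function, std::size_t lastArg) {
  log.push_back({message, file, line, function, lastArg});
}

// Same shape as the lowered IR: the "assert" block runs only when the
// condition is false, then control continues in the "after" block.
inline void loweredAssert(std::vector<AssertFailure> &log, bool condition,
                          const char *message) {
  if (!condition) {
    // The patch falls back to "(unknown)" and line 0 when the op carries no
    // usable location, and always passes the constant 1 as the last argument.
    mockAssertfail(log, message, "(unknown)", 0u, "(unknown)", 1);
  }
  // "after" block: execution continues here in this mock.
}
```

Calling `loweredAssert(log, true, "msg")` leaves the log empty, while a false condition appends one entry carrying the message and the fallback location strings.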

@llvmbot (Member) commented Dec 18, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new operation to the GPU dialect: gpu.assert. This op is a device-side assertion that can be used for debugging and/or verifying invariants.

This commit also includes an NVIDIA-specific lowering to __assertfail.

Note: getUniqueFormatGlobalName, getOrCreateFormatStringConstant and getOrDefineFunction are moved to GPUOpsLowering.h, so that they can be reused.
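
The reuse-or-create behavior of the string constant helper can be sketched in standalone C++ as follows. This is a simplified model under stated assumptions: `Module`, `Global`, `uniqueSymbolName` and `getOrCreateString` are hypothetical stand-ins, and the lookup compares only the string contents, whereas the helper in the patch also compares the global's type, constness, alignment and address space.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an LLVM global holding a constant string.
struct Global {
  std::string name;
  std::string value; // stored with a trailing '\0', as in the patch
};

// Hypothetical stand-in for the symbol table of a gpu.module.
struct Module {
  std::vector<Global> globals;

  bool hasSymbol(const std::string &name) const {
    for (const Global &g : globals)
      if (g.name == name)
        return true;
    return false;
  }
};

// Append an increasing counter to the prefix until the name is unused.
inline std::string uniqueSymbolName(const Module &m,
                                    const std::string &prefix) {
  unsigned n = 0;
  std::string name;
  do {
    name = prefix + std::to_string(n++);
  } while (m.hasSymbol(name));
  return name;
}

// Return the index of a global with identical contents, creating it if needed.
inline std::size_t getOrCreateString(Module &m, const std::string &prefix,
                                     const std::string &str) {
  std::string value = str;
  value.push_back('\0'); // null-terminate for C
  for (std::size_t i = 0; i < m.globals.size(); ++i)
    if (m.globals[i].value == value)
      return i; // reuse the existing global, regardless of its name prefix
  m.globals.push_back({uniqueSymbolName(m, prefix), value});
  return m.globals.size() - 1;
}
```

Requesting the same string twice yields the same global, so repeated assertions with an identical message or file name share one constant.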


Patch is 20.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120431.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+11)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+60-62)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h (+21)
  • (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+102-1)
  • (modified) mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp (-1)
  • (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+29)
  • (modified) mlir/test/Dialect/GPU/ops.mlir (+9)
  • (added) mlir/test/Integration/GPU/CUDA/assert.mlir (+38)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 42a017db300af6..793d663c0322a5 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -38,6 +38,17 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 class GPU_Op<string mnemonic, list<Trait> traits = []> :
     Op<GPU_Dialect, mnemonic, traits>;
 
+def GPU_AssertOp : GPU_Op<"assert"> {
+  let summary = "Device-side assertion";
+  let description = [{
+    The `gpu.assert` op is a device-side assertion. If the given `condition`
+    is 0, the kernel execution is aborted, optionally with the given error
+    message. This op is useful for debugging and verifying invariants.
+  }];
+  let arguments = (ins I1:$condition, OptionalAttr<StrAttr>:$message);
+  let assemblyFormat = "$condition (`,` $message^)? attr-dict";
+}
+
 def GPU_Dimension : I32EnumAttr<"Dimension",
     "a dimension, either 'x', 'y', or 'z'",
     [
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index b3c3fd4956d0bb..544fc57949e24d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -19,6 +19,59 @@
 
 using namespace mlir;
 
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
+                                           Location loc, OpBuilder &b,
+                                           StringRef name,
+                                           LLVM::LLVMFunctionType type) {
+  LLVM::LLVMFuncOp ret;
+  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointToStart(moduleOp.getBody());
+    ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+  }
+  return ret;
+}
+
+static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+                                           StringRef prefix) {
+  // Get a unique global name.
+  unsigned stringNumber = 0;
+  SmallString<16> stringConstName;
+  do {
+    stringConstName.clear();
+    (prefix + Twine(stringNumber++)).toStringRef(stringConstName);
+  } while (moduleOp.lookupSymbol(stringConstName));
+  return stringConstName;
+}
+
+LLVM::GlobalOp
+mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+                                gpu::GPUModuleOp moduleOp, Type llvmI8,
+                                StringRef namePrefix, StringRef str,
+                                uint64_t alignment, unsigned addrSpace) {
+  llvm::SmallString<20> nullTermStr(str);
+  nullTermStr.push_back('\0'); // Null terminate for C
+  auto globalType =
+      LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
+  StringAttr attr = b.getStringAttr(nullTermStr);
+
+  // Try to find existing global.
+  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
+        globalOp.getValueAttr() == attr &&
+        globalOp.getAlignment().value_or(0) == alignment &&
+        globalOp.getAddrSpace() == addrSpace)
+      return globalOp;
+
+  // Not found: create new global.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(moduleOp.getBody());
+  SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
+  return b.create<LLVM::GlobalOp>(loc, globalType,
+                                  /*isConstant=*/true, LLVM::Linkage::Internal,
+                                  name, attr, alignment, addrSpace);
+}
+
 LogicalResult
 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   return success();
 }
 
-static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
-  const char formatStringPrefix[] = "printfFormat_";
-  // Get a unique global name.
-  unsigned stringNumber = 0;
-  SmallString<16> stringConstName;
-  do {
-    stringConstName.clear();
-    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
-  } while (moduleOp.lookupSymbol(stringConstName));
-  return stringConstName;
-}
-
-/// Create an global that contains the given format string. If a global with
-/// the same format string exists already in the module, return that global.
-static LLVM::GlobalOp getOrCreateFormatStringConstant(
-    OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
-    StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
-  llvm::SmallString<20> formatString(str);
-  formatString.push_back('\0'); // Null terminate for C
-  auto globalType =
-      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
-  StringAttr attr = b.getStringAttr(formatString);
-
-  // Try to find existing global.
-  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
-    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
-        globalOp.getValueAttr() == attr &&
-        globalOp.getAlignment().value_or(0) == alignment &&
-        globalOp.getAddrSpace() == addrSpace)
-      return globalOp;
-
-  // Not found: create new global.
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPointToStart(moduleOp.getBody());
-  SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
-  return b.create<LLVM::GlobalOp>(loc, globalType,
-                                  /*isConstant=*/true, LLVM::Linkage::Internal,
-                                  name, attr, alignment, addrSpace);
-}
-
-template <typename T>
-static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
-                                            ConversionPatternRewriter &rewriter,
-                                            StringRef name,
-                                            LLVM::LLVMFunctionType type) {
-  LLVM::LLVMFuncOp ret;
-  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
-                                            LLVM::Linkage::External);
-  }
-  return ret;
-}
-
 LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
   Value printfDesc = printfBeginCall.getResult();
 
   // Create the global op or find an existing one.
-  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
-      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+  LLVM::GlobalOp global = getOrCreateStringConstant(
+      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
 
   // Get a pointer to the format string's first element and pass it to printf()
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
 
   // Create the global op or find an existing one.
-  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
-      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
-      addressSpace);
+  LLVM::GlobalOp global = getOrCreateStringConstant(
+      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
+      /*alignment=*/0, addressSpace);
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
       getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
 
   // Create the global op or find an existing one.
-  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
-      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+  LLVM::GlobalOp global = getOrCreateStringConstant(
+      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 444a07a93ca36e..e73a74845d2b66 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
 
 namespace mlir {
 
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+/// Find or create an external function declaration in the given module.
+LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
+                                     OpBuilder &b, StringRef name,
+                                     LLVM::LLVMFunctionType type);
+
+/// Create a global that contains the given string. If a global with the same
+/// string already exists in the module, return that global.
+LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
+                                         gpu::GPUModuleOp moduleOp, Type llvmI8,
+                                         StringRef namePrefix, StringRef str,
+                                         uint64_t alignment = 0,
+                                         unsigned addrSpace = 0);
+
+//===----------------------------------------------------------------------===//
+// Lowering Patterns
+//===----------------------------------------------------------------------===//
+
 /// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
 /// create a 0-sized global array symbol similar as LLVM expects. It constructs
 /// a memref descriptor with these values and return it.
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index b343cf71e3a2e7..f95f6ea0ff41c4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -236,6 +237,105 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   }
 };
 
+/// Lowering of gpu.assert into a conditional __assertfail.
+struct GPUAssertOpToAssertfailLowering
+    : public ConvertOpToLLVMPattern<gpu::AssertOp> {
+  using ConvertOpToLLVMPattern<gpu::AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::AssertOp assertOp, gpu::AssertOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = assertOp.getLoc();
+    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
+    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
+    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
+    Type ptrType = LLVM::LLVMPointerType::get(ctx);
+    Type voidType = LLVM::LLVMVoidType::get(ctx);
+
+    // Find or create __assertfail function declaration.
+    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
+    auto assertfailType = LLVM::LLVMFunctionType::get(
+        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
+    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "__assertfail", assertfailType);
+    assertfailDecl.setPassthroughAttr(
+        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
+
+    // Split blocks and insert conditional branch.
+    // ^before:
+    //   ...
+    //   cf.cond_br %condition, ^after, ^assert
+    // ^assert:
+    //   gpu.assert
+    //   cf.br ^after
+    // ^after:
+    //   ...
+    Block *beforeBlock = assertOp->getBlock();
+    Block *assertBlock =
+        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+    Block *afterBlock =
+        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+    rewriter.setInsertionPointToEnd(beforeBlock);
+    rewriter.create<cf::CondBranchOp>(loc, adaptor.getCondition(), afterBlock,
+                                      assertBlock);
+    rewriter.setInsertionPointToEnd(assertBlock);
+    rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+    // Continue gpu.assert lowering.
+    rewriter.setInsertionPoint(assertOp);
+
+    // Populate file name, file number and function name from the location of
+    // the AssertOp.
+    StringRef fileName = "(unknown)";
+    StringRef funcName = "(unknown)";
+    int32_t fileLine = 0;
+    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
+      fileName = fileLineColLoc.getFilename().strref();
+      fileLine = fileLineColLoc.getStartLine();
+    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+      funcName = nameLoc.getName().strref();
+      if (auto fileLineColLoc =
+              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
+        fileName = fileLineColLoc.getFilename().strref();
+        fileLine = fileLineColLoc.getStartLine();
+      }
+    }
+    // Extract message string.
+    StringRef message = "";
+    if (assertOp.getMessage().has_value())
+      message = *assertOp.getMessage();
+
+    // Create constants.
+    auto getGlobal = [&](LLVM::GlobalOp global) {
+      // Get a pointer to the format string's first element.
+      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
+          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
+          global.getSymNameAttr());
+      Value start =
+          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+      return start;
+    };
+    Value assertMessage = getGlobal(getOrCreateStringConstant(
+        rewriter, loc, moduleOp, i8Type, "assert_message_", message));
+    Value assertFile = getGlobal(getOrCreateStringConstant(
+        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
+    Value assertFunc = getGlobal(getOrCreateStringConstant(
+        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
+    Value assertLine =
+        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
+    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
+
+    // Insert function call to __assertfail.
+    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
+                                 assertFunc, c1};
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
+                                              arguments);
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -358,7 +458,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
   using gpu::index_lowering::IndexKind;
   using gpu::index_lowering::IntrType;
   populateWithGenerated(patterns);
-  patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
+  patterns.add<GPUPrintfOpToVPrintfLowering, GPUAssertOpToAssertfailLowering>(
+      converter);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                       NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
index fb440756e0c1d5..20d7372eef85d5 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
@@ -73,7 +73,6 @@ void buildCommonPassPipeline(
 //===----------------------------------------------------------------------===//
 void buildGpuPassPipeline(OpPassManager &pm,
                           const mlir::gpu::GPUToNVVMPipelineOptions &options) {
-  pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
   ConvertGpuOpsToNVVMOpsOptions opt;
   opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv;
   opt.indexBitwidth = options.indexBitWidth;
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 748dfe8c68fc7e..7e50734350be2a 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -969,6 +969,35 @@ gpu.module @test_module_50 {
   }
 }
 
+// CHECK-LABEL: gpu.module @test_module_51
+//       CHECK:   llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32}
+//       CHECK:   llvm.mlir.global internal constant @[[file_name:.*]]("within split at mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir:1 offset \00") {addr_space = 0 : i32}
+//       CHECK:   llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32}
+//       CHECK:   llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]}
+//       CHECK:   llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} {
+//       CHECK:     llvm.cond_br %[[cond]], ^[[assert_block:.*]], ^[[after_block:.*]]
+//       CHECK:   ^[[assert_block]]:  // pred: ^bb0
+//       CHECK:     %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr
+//       CHECK:     %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8>
+//       CHECK:     %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr
+//       CHECK:     %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<74 x i8>
+//       CHECK:     %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr
+//       CHECK:     %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i8>
+//       CHECK:     %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32
+//       CHECK:     %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:     llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
+//       CHECK:     llvm.br ^[[after_block]]
+//       CHECK:   ^[[after_block]]:
+//       CHECK:     llvm.return
+//       CHECK:   }
+
+gpu.module @test_module_51 {
+  gpu.func @test_assert(%arg0: i1) kernel {
+    gpu.assert %arg0, "assert message"
+    gpu.return
+  }
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
     %gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c0ff2044b76c40..d74ff9482b911d 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -500,3 +500,12 @@ func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4
   }
   return %2 : vector<4xi32>
 }
+
+// CHECK-LABEL: func @test_assert(
+func.func @test_assert(%cond : i1) {
+  // CHECK: gpu.assert %{{.*}}, "message"
+  gpu.assert %cond, "message"
+  // CHECK: gpu.assert %{{.*}}
+  gpu.assert %cond
+  return
+}
diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir
new file mode 100644
index 00000000000000..40745f26af16cf
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/assert.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void 2>&1 \
+// RUN: | FileCheck %s
+
+// CHECK-DAG: thread 0: print after passing assertion
+// CHECK-DAG: thread 1: print after passing assertion
+// CHECK-DAG: mlir/test/Integration/GPU/CUDA/asser...
[truncated]
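
The location handling in the lowering pattern above (fall back to "(unknown)" and line 0, take file and line from a file location, and take the function name from a named location that may wrap a file location) can be modeled in plain C++. This is only a sketch: `Loc`, `AssertSite` and `describe` are hypothetical names, and the flattened `Loc` struct is a simplification of MLIR's location attributes.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical, flattened model of an MLIR location.
struct Loc {
  enum Kind { Unknown, FileLine, Named } kind;
  std::string file;      // used by FileLine, and by Named with a child
  std::int32_t line;     // used by FileLine, and by Named with a child
  std::string name;      // used by Named
  bool hasFileLineChild; // Named only: wraps a file/line location
};

// The three pieces of location information passed to the failure handler.
struct AssertSite {
  std::string file;
  std::string function;
  std::int32_t line;
};

inline AssertSite describe(const Loc &loc) {
  // Defaults used when nothing can be recovered from the location.
  AssertSite site{"(unknown)", "(unknown)", 0};
  if (loc.kind == Loc::FileLine) {
    site.file = loc.file;
    site.line = loc.line;
  } else if (loc.kind == Loc::Named) {
    site.function = loc.name;
    if (loc.hasFileLineChild) {
      site.file = loc.file;
      site.line = loc.line;
    }
  }
  return site;
}
```

A bare file location therefore reports the file and line but keeps "(unknown)" as the function name, which matches what the conversion test in the diff checks for.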

@matthias-springer force-pushed the users/matthias-springer/gpu_assert branch from f8dd45f to f9f5866 on December 18, 2024 14:54
@matthias-springer changed the base branch from main to users/matthias-springer/do_not_strip_debug on December 18, 2024 14:54
@joker-eph (Collaborator) commented

Can we just use cf.assert?

@matthias-springer force-pushed the users/matthias-springer/gpu_assert branch from f9f5866 to 4d83429 on December 18, 2024 15:43
@matthias-springer (Member, Author) commented Dec 18, 2024

> Can we just use cf.assert ?

This is a good idea. But I first have to move a few things around because there is an existing pattern in populateControlFlowToLLVMConversionPatterns... In particular, #70982 must be resolved first. Marking this PR as "draft" until then.

@matthias-springer force-pushed the users/matthias-springer/gpu_assert branch from 4d83429 to 79ca017 on December 18, 2024 16:10
@matthias-springer marked this pull request as draft on December 18, 2024 16:10
@matthias-springer changed the title from "[mlir][GPU] Add gpu.assert op" to "[mlir][GPU] Add NVVM-specific cf.assert lowering" on Dec 18, 2024
Base automatically changed from users/matthias-springer/do_not_strip_debug to main December 19, 2024 14:05
matthias-springer added a commit that referenced this pull request Dec 20, 2024
Do not run `cf-to-llvm` as part of `func-to-llvm`. This commit fixes
#70982.

This commit changes how `func.func` ops are lowered to LLVM.
Previously, the signature of the entire region (i.e., entry block and
all other blocks in the `func.func` op) was converted as part of the
`func.func` lowering pattern.

Now, only the entry block is converted. The remaining block signatures
are converted together with `cf.br` and `cf.cond_br` as part of
`cf-to-llvm`. All unstructured control flow is now converted as part of
a single pass (`cf-to-llvm`). `func-to-llvm` no longer deals with
unstructured control flow.

Also add more test cases for control flow dialect ops.

Note: This PR is in preparation of #120431, which adds an additional
GPU-specific lowering for `cf.assert`. This was a problem because
`cf.assert` used to be converted as part of `func-to-llvm`.

Note for LLVM integration: If you see failures, add
`-convert-cf-to-llvm` to your pass pipeline.
@matthias-springer force-pushed the users/matthias-springer/gpu_assert branch 2 times, most recently from 1d2450d to d75a6ed, on December 20, 2024 13:32
@matthias-springer marked this pull request as ready for review on December 20, 2024 13:33
@llvmbot added the flang, mlir:openmp, flang:fir-hlfir, flang:openmp and flang:codegen labels on Dec 20, 2024
@matthias-springer force-pushed the users/matthias-springer/gpu_assert branch from d75a6ed to 8865986 on December 20, 2024 14:46
@grypp (Member) left a comment

Thanks for implementing this. It's definitely nice to have assert support.

@matthias-springer force-pushed the users/matthias-springer/gpu_assert branch from 86f4eff to dbd81f3 on January 6, 2025 09:28
@matthias-springer merged commit 599c739 into main on Jan 6, 2025
8 checks passed
@matthias-springer deleted the users/matthias-springer/gpu_assert branch on January 6, 2025 11:00
paulhuggett pushed a commit to paulhuggett/llvm-project that referenced this pull request Jan 7, 2025
This commit adds an NVIDIA-specific lowering of `cf.assert` to
`__assertfail`.

Note: `getUniqueFormatGlobalName`, `getOrCreateFormatStringConstant` and
`getOrDefineFunction` are moved to `GPUOpsLowering.h`, so that they can
be reused.
5 participants