[MLIR][LLVM] Tail call support for inline asm op #140826

Open: bcardosolopes wants to merge 3 commits into main

bcardosolopes
Member

No description provided.

@llvmbot
Member

llvmbot commented May 21, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-backend-amdgpu

Author: Bruno Cardoso Lopes (bcardosolopes)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/140826.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+11-3)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+2-1)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp (+1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+15)
  • (modified) mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp (+2-1)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp (+4)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+4)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+8)
  • (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+2-2)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+8-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 61ba8f7b991c8..ba13806a38a4c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -758,9 +758,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     the LLVM function type that uses an explicit void type to model functions
     that do not return a value.
 
-    If this operatin has the `no_inline` attribute, then this specific function call 
-    will never be inlined. The opposite behavior will occur if the call has `always_inline` 
-    attribute. The `inline_hint` attribute indicates that it is desirable to inline 
+    If this operatin has the `no_inline` attribute, then this specific function call
+    will never be inlined. The opposite behavior will occur if the call has `always_inline`
+    attribute. The `inline_hint` attribute indicates that it is desirable to inline
     this function call.
 
     Examples:
@@ -2298,6 +2298,9 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
     written, or referenced.
     Attempting to define or reference any symbol or any global behavior is
     considered undefined behavior at this time.
+    If `tail_call_kind` is used, the operation behaves like the specified
+    tail call kind. The `musttail` kind it's not available for this operation,
+    since it isn't supported by LLVM's inline asm.
   }];
   let arguments = (
     ins Variadic<LLVM_Type>:$operands,
@@ -2305,6 +2308,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
         StrAttr:$constraints,
         UnitAttr:$has_side_effects,
         UnitAttr:$is_align_stack,
+        OptionalAttr<
+        DefaultValuedAttr<TailCallKind, "TailCallKind::None">>:$tail_call_kind,
         OptionalAttr<
           DefaultValuedAttr<AsmATTOrIntel, "AsmDialect::AD_ATT">>:$asm_dialect,
         OptionalAttr<ArrayAttr>:$operand_attrs);
@@ -2314,6 +2319,7 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
   let assemblyFormat = [{
     (`has_side_effects` $has_side_effects^)?
     (`is_align_stack` $is_align_stack^)?
+    (`tail_call_kind` `=` $tail_call_kind^)?
     (`asm_dialect` `=` $asm_dialect^)?
     (`operand_attrs` `=` $operand_attrs^)?
     attr-dict
@@ -2326,6 +2332,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
       return "elementtype";
     }
   }];
+
+  let hasVerifier = 1;
 }
 
 //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 0694cf27faff4..ae07666574260 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -439,7 +439,8 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
           op,
           /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
           /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
-          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
+          /*is_align_stack=*/false, /*tail_call_kind=*/nullptr,
+          /*asm_dialect=*/asmDialectAttr,
           /*operand_attrs=*/ArrayAttr());
       return success();
     }
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index eb3558d2460e4..c5e5094a35010 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -572,6 +572,7 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
       /*constraints=*/constraintStr,
       /*has_side_effects=*/true,
       /*is_align_stack=*/false,
+      /*tail_call_kind=*/nullptr,
       /*asm_dialect=*/asmDialectAttr,
       /*operand_attrs=*/ArrayAttr());
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index f2e71e7795c3e..9b83fd190b9eb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -140,6 +140,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
       /*constraints=*/registerConstraints.data(),
       /*has_side_effects=*/interfaceOp.hasSideEffect(),
       /*is_align_stack=*/false,
+      /*tail_call_kind=*/nullptr,
       /*asm_dialect=*/asmDialectAttr,
       /*operand_attrs=*/ArrayAttr());
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d8abf6fd41301..c7528c970a4ba 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -4042,6 +4042,21 @@ LogicalResult LLVM::masked_scatter::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// InlineAsmOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult InlineAsmOp::verify() {
+  if (!getTailCallKindAttr())
+    return success();
+
+  if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
+    return emitOpError(
+        "tail call kind 'musttail' is not supported by this operation");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LLVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
index 3fc05c8cb8707..5bce7c7625e8d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -41,7 +41,8 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
   auto asmOp = b.create<LLVM::InlineAsmOp>(
       v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
       /*constraints=*/asmCstr, /*has_side_effects=*/false,
-      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
+      /*is_align_stack=*/false, /*tail_call_kind=*/nullptr,
+      /*asm_dialect=*/asmDialectAttr,
       /*operand_attrs=*/ArrayAttr());
   return asmOp.getResult(0);
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index c954dffb376bb..37387130ee012 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/MatrixBuilder.h"
 #include "llvm/IR/Operator.h"
@@ -507,6 +508,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     llvm::CallInst *inst = builder.CreateCall(
         inlineAsmInst,
         moduleTranslation.lookupValues(inlineAsmOp.getOperands()));
+    if (inlineAsmOp.getTailCallKindAttr())
+      inst->setTailCallKind(convertTailCallKindToLLVM(
+          inlineAsmOp.getTailCallKindAttr().getTailCallKind()));
     if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
       llvm::AttributeList attrList;
       for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8945ae933dd65..8041480ac687f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -2201,6 +2201,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                 builder.getStringAttr(asmI->getAsmString()),
                 builder.getStringAttr(asmI->getConstraintString()),
                 asmI->hasSideEffects(), asmI->isAlignStack(),
+                callInst->isTailCall()
+                    ? TailCallKindAttr::get(mlirModule.getContext(),
+                                            TailCallKind::Tail)
+                    : nullptr,
                 AsmDialectAttr::get(
                     mlirModule.getContext(),
                     convertAsmDialectFromLLVM(asmI->getDialect())),
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f5adf4b3bf33d..251ca716c7a7a 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1882,3 +1882,11 @@ llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2
   %0 = llvm.mlir.constant([2.5, 7.4]) : !llvm.array<2 x f64>
   llvm.return %0 : !llvm.array<2 x f64>
 }
+
+// ----
+
+llvm.func @inlineAsmMustTail(%arg0: i32, %arg1 : !llvm.ptr) {
+  // expected-error@+1 {{op tail call kind 'musttail' is not supported}}
+  %8 = llvm.inline_asm tail_call_kind = <musttail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 68ef47c3f42f1..d92abd4408249 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -554,8 +554,8 @@ define i32 @inlineasm(i32 %arg1) {
 define void @inlineasm2() {
   %p = alloca ptr, align 8
   ; CHECK: {{.*}} = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
-  ; CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" {{.*}} : (!llvm.ptr) -> !llvm.void
-  call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %p)
+  ; CHECK-NEXT: llvm.inline_asm has_side_effects tail_call_kind = <tail> asm_dialect = att operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" {{.*}} : (!llvm.ptr) -> !llvm.void
+  tail call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %p)
   ret void
 }
 
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 237612244d8de..7742259e7a478 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2081,8 +2081,14 @@ llvm.func @useInlineAsm(%arg0: i32, %arg1 : !llvm.ptr) {
   // CHECK-NEXT:  call { i8, i8 } asm "foo", "=r,=r,r"(i32 {{.*}})
   %5 = llvm.inline_asm "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
 
-  // CHECK-NEXT:  call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %1)
-  %6 = llvm.inline_asm has_side_effects operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" %arg1 : (!llvm.ptr) -> !llvm.void
+  // CHECK-NEXT:  tail call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %1)
+  %6 = llvm.inline_asm has_side_effects tail_call_kind = <tail> operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" %arg1 : (!llvm.ptr) -> !llvm.void
+
+  // CHECK-NEXT:  = call { i8, i8 } asm "foo", "=r,=r,r"(i32 {{.*}})
+  %7 = llvm.inline_asm tail_call_kind = <none> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
+
+  // CHECK-NEXT:  notail call { i8, i8 } asm "foo", "=r,=r,r"(i32 {{.*}})
+  %8 = llvm.inline_asm tail_call_kind = <notail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
 
   llvm.return
 }

@llvmbot
Member

llvmbot commented May 21, 2025

@llvm/pr-subscribers-mlir-gpu

bcardosolopes requested review from gysit and Dinistro on May 21, 2025, 00:48
Contributor

@gysit gysit left a comment

LGTM modulo one question.

  }];
  let arguments = (
    ins Variadic<LLVM_Type>:$operands,
        StrAttr:$asm_string,
        StrAttr:$constraints,
        UnitAttr:$has_side_effects,
        UnitAttr:$is_align_stack,
        OptionalAttr<
        DefaultValuedAttr<TailCallKind, "TailCallKind::None">>:$tail_call_kind,
Contributor

Suggested change
-        DefaultValuedAttr<TailCallKind, "TailCallKind::None">>:$tail_call_kind,
+          DefaultValuedAttr<TailCallKind, "TailCallKind::None">>:$tail_call_kind,

nit: use the same indent as the line below. Is the OptionalAttr needed? Couldn't we use TailCallKind::None when no tail call kind is defined?

Member Author

This is to avoid printing tail_call_kind = <none> every time InlineAsmOp is used.

Member Author

Fixed indentation

Contributor

But now we once again have two ways of modelling the same state, which is something that can be a bit messy to deal with.
According to the doc (https://mlir.llvm.org/docs/DefiningDialects/Operations/#attributes-with-default-values):

"The generated operation printing function will not print default-valued attributes when the attribute value is equal to the default."

If that doesn't work in this case, then the assembly format is probably broken in some way.
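
The "two ways of modelling the same state" concern can be sketched in standalone C++. This is only an illustrative model, not MLIR code: `std::optional` stands in for `OptionalAttr`, and `OptionalKind`, `effectiveKind`, `printKind`, `verifyKind`, and `demo` are hypothetical names introduced here. It assumes the printer behaves as the quoted documentation describes, eliding a default-valued attribute that equals its default.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for the dialect's tail call kind enum.
enum class TailCallKind { None, Tail, MustTail, NoTail };

// Models OptionalAttr<DefaultValuedAttr<...>>: the attribute may be absent
// (std::nullopt) or present with any value, including None.
using OptionalKind = std::optional<TailCallKind>;

// "Absent" and "present with None" both mean a plain call, so every consumer
// of the optional form has to normalize before comparing.
inline TailCallKind effectiveKind(OptionalKind kind) {
  return kind.value_or(TailCallKind::None);
}

// Models a printer for a plain default-valued attribute: the default value
// (None) is elided, so no optional wrapper is needed to keep output short.
inline std::string printKind(TailCallKind kind) {
  switch (kind) {
  case TailCallKind::None:
    return "";
  case TailCallKind::Tail:
    return "tail_call_kind = <tail>";
  case TailCallKind::MustTail:
    return "tail_call_kind = <musttail>";
  case TailCallKind::NoTail:
    return "tail_call_kind = <notail>";
  }
  return "";
}

// Mirrors the rule enforced by the verifier in this patch: musttail is
// rejected, every other kind is accepted.
inline bool verifyKind(TailCallKind kind) {
  return kind != TailCallKind::MustTail;
}

// Small usage demonstration of the points above.
inline void demo() {
  OptionalKind absent;                           // attribute not set
  OptionalKind explicitNone(TailCallKind::None); // attribute set to None
  assert(absent != explicitNone);                // two distinct encodings...
  assert(effectiveKind(absent) == effectiveKind(explicitNone)); // ...one meaning
  assert(printKind(effectiveKind(absent)).empty());
  assert(!verifyKind(TailCallKind::MustTail));
}
```

With a single default-valued attribute there is only one encoding of "no tail call", and code such as the verifier or the translation to LLVM IR would not need the separate null check.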

4 participants