Skip to content

[mlir][LLVM] handle argument and result attributes in llvm.call and llvm.invoke #123177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 11, 2025

Conversation

jeanPerier
Copy link
Contributor

@jeanPerier jeanPerier commented Jan 16, 2025

Second patch of this RFC that update llvm.call/llvm.invoke pretty printer/parser and the llvm ir import/export to deal with the argument and result attributes.

This patch is made on top of #123176 that modified the CallOpInterface and added the argument and result attributes to llvm.call and llvm.invoke without doing anything with them.

@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: None (jeanPerier)

Changes

Second patch of this RFC that adds argument and result attributes to llvm.call and update the llvm ir import/export.

This patch is made on top of #123176 that modified the CallOpInterface.


Patch is 23.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123177.diff

12 Files Affected:

  • (modified) llvm/include/llvm/IR/InstrTypes.h (+11)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+3-1)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+6-2)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+7-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+45-22)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp (+21)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+35)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+38-14)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+2)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+20)
  • (added) mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll (+25)
  • (added) mlir/test/Target/LLVMIR/call-argument-attributes.mlir (+17)
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index b8d9cc10292f4a..0e391325eebdce 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1490,6 +1490,11 @@ class CallBase : public Instruction {
     Attrs = Attrs.addRetAttribute(getContext(), Attr);
   }
 
+  /// Adds attributes to the return value.
+  void addRetAttrs(const AttrBuilder &B) {
+    Attrs = Attrs.addRetAttributes(getContext(), B);
+  }
+
   /// Adds the attribute to the indicated argument
   void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) {
     assert(ArgNo < arg_size() && "Out of bounds");
@@ -1502,6 +1507,12 @@ class CallBase : public Instruction {
     Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr);
   }
 
+  /// Adds attributes to the indicated argument
+  void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) {
+    assert(ArgNo < arg_size() && "Out of bounds");
+    Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B);
+  }
+
   /// removes the attribute from the list of attributes.
   void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) {
     Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b2281536aa40b6..85f5c6cc8cca07 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -755,7 +755,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   VariadicOfVariadic<LLVM_Type,
                                      "op_bundle_sizes">:$op_bundle_operands,
                   DenseI32ArrayAttr:$op_bundle_sizes,
-                  OptionalAttr<ArrayAttr>:$op_bundle_tags);
+                  OptionalAttr<ArrayAttr>:$op_bundle_tags,
+                  OptionalAttr<DictArrayAttr>:$arg_attrs,
+                  OptionalAttr<DictArrayAttr>:$res_attrs);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 33c9af7c6335a4..86e1d6a04cd096 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -326,14 +326,18 @@ class ModuleImport {
                                            SmallVectorImpl<Type> &types,
                                            SmallVectorImpl<Value> &operands,
                                            bool allowInlineAsm = false);
-  /// Converts the parameter attributes attached to `func` and adds them to the
-  /// `funcOp`.
+  /// Converts the parameter and result attributes attached to `func` and adds
+  /// them to the `funcOp`.
   void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
                                   OpBuilder &builder);
   /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
   /// DictionaryAttr for the LLVM dialect.
   DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
                                            OpBuilder &builder);
+  /// Converts the parameter and result attributes attached to `call` and adds
+  /// them to the `callOp`.
+  void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
+                                  OpBuilder &builder);
   /// Returns the builtin type equivalent to the given LLVM dialect type or
   /// nullptr if there is no equivalent. The returned type can be used to create
   /// an attribute for a GlobalOp or a ConstantOp.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 1b62437761ed9d..88fc17ca4fda24 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -228,6 +228,11 @@ class ModuleTranslation {
                             /*recordInsertions=*/false);
   }
 
+  /// Translates parameter attributes of a call and adds them to the returned
+  /// AttrBuilder. Returns failure if any of the translations failed.
+  FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOp callOp, int argIdx,
+                                                     DictionaryAttr paramAttrs);
+
   /// Gets the named metadata in the LLVM IR module being constructed, creating
   /// it if it does not exist.
   llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
@@ -346,8 +351,8 @@ class ModuleTranslation {
   convertDialectAttributes(Operation *op,
                            ArrayRef<llvm::Instruction *> instructions);
 
-  /// Translates parameter attributes and adds them to the returned AttrBuilder.
-  /// Returns failure if any of the translations failed.
+  /// Translates parameter attributes of a function and adds them to the
+  /// returned AttrBuilder. Returns failure if any of the translations failed.
   FailureOr<llvm::AttrBuilder>
   convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ef1e0222e05f06..6c4988bac7813e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1033,6 +1033,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
         /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1060,6 +1061,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
         /*convergent=*/nullptr,
         /*no_unwind=*/nullptr, /*will_return=*/nullptr,
         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
         /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1073,6 +1075,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1087,6 +1090,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
@@ -1331,42 +1335,52 @@ void CallOp::print(OpAsmPrinter &p) {
                            getVarCalleeTypeAttrName(), getCConvAttrName(),
                            getOperandSegmentSizesAttrName(),
                            getOpBundleSizesAttrName(),
-                           getOpBundleTagsAttrName()});
+                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
+                           getResAttrsAttrName()});
 
   p << " : ";
   if (!isDirect)
     p << getOperand(0).getType() << ", ";
 
   // Reconstruct the function MLIR function type from operand and result types.
-  p.printFunctionalType(args.getTypes(), getResultTypes());
+  call_interface_impl::printFunctionSignature(
+      p, *this, args.getTypes(), /*isVariadic=*/false, getResultTypes());
 }
 
 /// Parses the type of a call operation and resolves the operands if the parsing
 /// succeeds. Returns failure otherwise.
 static ParseResult parseCallTypeAndResolveOperands(
     OpAsmParser &parser, OperationState &result, bool isDirect,
-    ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+    ArrayRef<OpAsmParser::UnresolvedOperand> operands,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
   SMLoc trailingTypesLoc = parser.getCurrentLocation();
   SmallVector<Type> types;
-  if (parser.parseColonTypeList(types))
+  if (parser.parseColon())
     return failure();
-
-  if (isDirect && types.size() != 1)
-    return parser.emitError(trailingTypesLoc,
-                            "expected direct call to have 1 trailing type");
-  if (!isDirect && types.size() != 2)
-    return parser.emitError(trailingTypesLoc,
-                            "expected indirect call to have 2 trailing types");
-
-  auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
-  if (!funcType)
+  if (!isDirect) {
+    types.emplace_back();
+    if (parser.parseType(types.back()))
+      return failure();
+    if (parser.parseOptionalComma())
+      return parser.emitError(
+          trailingTypesLoc, "expected indirect call to have 2 trailing types");
+  }
+  SmallVector<Type> argTypes;
+  SmallVector<Type> resTypes;
+  if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
+                                                  resTypes, resultAttrs)) {
+    if (isDirect)
+      return parser.emitError(trailingTypesLoc,
+                              "expected direct call to have 1 trailing types");
     return parser.emitError(trailingTypesLoc,
                             "expected trailing function type");
-  if (funcType.getNumResults() > 1)
+  }
+
+  if (resTypes.size() > 1)
     return parser.emitError(trailingTypesLoc,
                             "expected function with 0 or 1 result");
-  if (funcType.getNumResults() == 1 &&
-      llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
+  if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
     return parser.emitError(trailingTypesLoc,
                             "expected a non-void result type");
 
@@ -1374,12 +1388,12 @@ static ParseResult parseCallTypeAndResolveOperands(
   // indirect calls, while the types list is emtpy for direct calls.
   // Append the function input types to resolve the call operation
   // operands.
-  llvm::append_range(types, funcType.getInputs());
+  llvm::append_range(types, argTypes);
   if (parser.resolveOperands(operands, types, parser.getNameLoc(),
                              result.operands))
     return failure();
-  if (funcType.getNumResults() != 0)
-    result.addTypes(funcType.getResults());
+  if (resTypes.size() != 0)
+    result.addTypes(resTypes);
 
   return success();
 }
@@ -1493,8 +1507,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   // Parse the trailing type list and resolve the operands.
-  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+  SmallVector<DictionaryAttr> argAttrs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      argAttrs, resultAttrs))
     return failure();
+  call_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, argAttrs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
   if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                               opBundleOperandTypes,
                               getOpBundleSizesAttrName(result.name)))
@@ -1714,7 +1734,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   // Parse the trailing type list and resolve the function operands.
-  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+  SmallVector<DictionaryAttr> argAttrs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      argAttrs, resultAttrs))
     return failure();
   if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                               opBundleOperandTypes,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 2084e527773ca8..52f42df60f0015 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -265,6 +265,27 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     if (callOp.getWillReturnAttr())
       call->addFnAttr(llvm::Attribute::WillReturn);
 
+    if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr())
+      for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
+        if (auto argAttrs = llvm::cast<DictionaryAttr>(argAttrsAttr)) {
+          FailureOr<llvm::AttrBuilder> attrBuilder =
+              moduleTranslation.convertParameterAttrs(callOp, argIdx, argAttrs);
+          if (failed(attrBuilder))
+            return failure();
+          call->addParamAttrs(argIdx, *attrBuilder);
+        }
+      }
+
+    ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
+    if (resAttrsArray && resAttrsArray.size() == 1)
+      if (auto resAttrs = llvm::cast<DictionaryAttr>(resAttrsArray[0])) {
+        FailureOr<llvm::AttrBuilder> attrBuilder =
+            moduleTranslation.convertParameterAttrs(callOp, -1, resAttrs);
+        if (failed(attrBuilder))
+          return failure();
+        call->addRetAttrs(*attrBuilder);
+      }
+
     if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
       llvm::MemoryEffects memEffects =
           llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index eba86f06d09056..f65bf6584d51f2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1641,6 +1641,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
         callOp.setWillReturn(true);
 
+      // Handle parameter and result attributes.
+      convertParameterAttributes(callInst, callOp, builder);
+
       llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
       ModRefInfo othermem = convertModRefInfoFromLLVM(
           memEffects.getModRef(llvm::MemoryEffects::Location::Other));
@@ -2084,6 +2087,38 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
       builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
 }
 
+void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
+                                              CallOpInterface callOp,
+                                              OpBuilder &builder) {
+  auto llvmAttrs = call->getAttributes();
+  SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
+  bool anyArgAttrs = false;
+  for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
+    llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
+    if (llvmArgAttrsSet.back().hasAttributes())
+      anyArgAttrs = true;
+  }
+  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
+    SmallVector<Attribute> attrs;
+    for (auto &dict : dictAttrs)
+      attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
+    return builder.getArrayAttr(attrs);
+  };
+  if (anyArgAttrs) {
+    SmallVector<DictionaryAttr> argAttrs;
+    for (auto &llvmArgAttrs : llvmArgAttrsSet)
+      argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
+    callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
+  }
+
+  llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
+  if (!llvmResAttr.hasAttributes())
+    return;
+  SmallVector<DictionaryAttr, 1> resAttrs;
+  resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder));
+  callOp.setResAttrsAttr(getArrayAttr(resAttrs));
+}
+
 LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   clearRegionState();
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 4367100e3aca68..b2d2c1cddca318 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1563,6 +1563,26 @@ static void convertFunctionKernelAttributes(LLVMFuncOp func,
   }
 }
 
+static void convertParameterAttr(llvm::AttrBuilder &attrBuilder,
+                                 llvm::Attribute::AttrKind llvmKind,
+                                 NamedAttribute namedAttr,
+                                 ModuleTranslation &moduleTranslation) {
+  llvm::TypeSwitch<Attribute>(namedAttr.getValue())
+      .Case<TypeAttr>([&](auto typeAttr) {
+        attrBuilder.addTypeAttr(
+            llvmKind, moduleTranslation.convertType(typeAttr.getValue()));
+      })
+      .Case<IntegerAttr>([&](auto intAttr) {
+        attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+      })
+      .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); })
+      .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
+        attrBuilder.addConstantRangeAttr(
+            llvmKind,
+            llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper()));
+      });
+}
+
 FailureOr<llvm::AttrBuilder>
 ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
                                          DictionaryAttr paramAttrs) {
@@ -1573,20 +1593,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
     auto it = attrNameToKindMapping.find(namedAttr.getName());
     if (it != attrNameToKindMapping.end()) {
       llvm::Attribute::AttrKind llvmKind = it->second;
-
-      llvm::TypeSwitch<Attribute>(namedAttr.getValue())
-          .Case<TypeAttr>([&](auto typeAttr) {
-            attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
-          })
-          .Case<IntegerAttr>([&](auto intAttr) {
-            attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
-          })
-          .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); })
-          .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
-            attrBuilder.addConstantRangeAttr(
-                llvmKind, llvm::ConstantRange(rangeAttr.getLower(),
-                                              rangeAttr.getUpper()));
-          });
+      convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this);
     } else if (namedAttr.getNameDialect()) {
       if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
         return failure();
@@ -1596,6 +1603,23 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
   return attrBuilder;
 }
 
+FailureOr<llvm::AttrBuilder>
+ModuleTranslation::convertParameterAttrs(CallOp, int argIdx,
+                                         DictionaryAttr paramAttrs) {
+  llvm::AttrBuilder attrBuilder(llvmModule->getContext());
+  auto attrNameToKindMapping = getAttrNameToKindMapping();
+
+  for (auto namedAttr : paramAttrs) {
+    auto it = attrNameToKindMapping.find(namedAttr.getName());
+    if (it != attrNameToKindMapping.end()) {
+      llvm::Attribute::AttrKind llvmKind = it->second;
+      convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this);
+    }
+  }
+
+  return attrBuilder;
+}
+
 LogicalResult ModuleTranslation::convertFunctionSignatures() {
   // Declare all functions first because there may be function calls that form a
   // call graph with cycles, or global initializers that reference functions.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 25806d9d0edd72..14cdcc06625c06 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -235,6 +235,7 @@ func.func @call_missing_ptr_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
 func.func private @standard_func_callee()
 
 func.func @call_missing_ptr_type(%arg : i8) {
+  // expected-error@+2 {{expected '('}}
   // expected-error@+1 {{expected direct call to have 1 trailing type}}
   llvm.call @standard_func_callee(%arg) : !llvm.pt...
[truncated]

@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Jan 16, 2025
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

I did leave some comments assuming the base commit will land soon.

Base automatically changed from users/jeanPerier/mlir-call-attrs-iface to main February 3, 2025 10:27
@jeanPerier jeanPerier force-pushed the users/jeanPerier/mlir-call-attrs-llvm branch from 076fbfe to b9a67a9 Compare February 3, 2025 16:39
@jeanPerier
Copy link
Contributor Author

Hi @gysit, thanks for the reviews!

I updated the patch to deal with llvm.invoke too.

@jeanPerier jeanPerier requested a review from gysit February 6, 2025 10:31
@jeanPerier jeanPerier changed the title [mlir][LLVM] add argument and result attributes to llvm.call [mlir][LLVM] handle argument and result attributes in llvm.call and llvm.invoke Feb 6, 2025
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

LGTM modulo some ultra nits.

Co-authored-by: Tobias Gysi <tobias.gysi@nextsilicon.com>
Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropped a few nits, but this is looking very good. Thanks for the work 🙂

Copy link
Contributor Author

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review!

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing the comments, LGTM!

@jeanPerier jeanPerier merged commit 99e1308 into main Feb 11, 2025
8 checks passed
@jeanPerier jeanPerier deleted the users/jeanPerier/mlir-call-attrs-llvm branch February 11, 2025 08:39
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…lvm.invoke (llvm#123177)

Update llvm.call/llvm.invoke pretty printer/parser and the llvm ir import/export
to deal with the argument and result attributes.

This patch is made on top of PR 123176 that modified the
CallOpInterface and added the argument and result attributes to
llvm.call and llvm.invoke without doing anything with them.

RFC: https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
…lvm.invoke (llvm#123177)

Update llvm.call/llvm.invoke pretty printer/parser and the llvm ir import/export
to deal with the argument and result attributes.

This patch is made on top of PR 123176 that modified the
CallOpInterface and added the argument and result attributes to
llvm.call and llvm.invoke without doing anything with them.

RFC: https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
…lvm.invoke (llvm#123177)

Update llvm.call/llvm.invoke pretty printer/parser and the llvm ir import/export
to deal with the argument and result attributes.

This patch is made on top of PR 123176 that modified the
CallOpInterface and added the argument and result attributes to
llvm.call and llvm.invoke without doing anything with them.

RFC: https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants