Skip to content

[mlir][LLVM] handle argument and result attributes in llvm.call and llvm.invoke #123177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions llvm/include/llvm/IR/InstrTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,11 @@ class CallBase : public Instruction {
Attrs = Attrs.addRetAttribute(getContext(), Attr);
}

/// Adds attributes to the return value.
void addRetAttrs(const AttrBuilder &B) {
Attrs = Attrs.addRetAttributes(getContext(), B);
}

/// Adds the attribute to the indicated argument
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) {
assert(ArgNo < arg_size() && "Out of bounds");
Expand All @@ -1502,6 +1507,12 @@ class CallBase : public Instruction {
Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr);
}

/// Adds attributes to the indicated argument
void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) {
assert(ArgNo < arg_size() && "Out of bounds");
Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B);
}

/// removes the attribute from the list of attributes.
void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) {
Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind);
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,18 @@ class ModuleImport {
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
/// Converts the parameter attributes attached to `func` and adds them to
/// the `funcOp`.
/// Converts the parameter and result attributes attached to `func` and adds
/// them to the `funcOp`.
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
OpBuilder &builder);
/// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
/// DictionaryAttr for the LLVM dialect.
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
OpBuilder &builder);
/// Converts the parameter and result attributes attached to `call` and adds
/// them to the `callOp`.
void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
OpBuilder &builder);
/// Converts the attributes attached to `inst` and adds them to the `op`.
LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
/// Converts the attributes attached to `inst` and adds them to the `op`.
Expand Down
9 changes: 7 additions & 2 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ class ModuleTranslation {
/*recordInsertions=*/false);
}

/// Translates parameter attributes of a call and adds them to the returned
/// AttrBuilder. Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOpInterface callOp,
DictionaryAttr paramAttrs);

/// Gets the named metadata in the LLVM IR module being constructed, creating
/// it if it does not exist.
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
Expand Down Expand Up @@ -346,8 +351,8 @@ class ModuleTranslation {
convertDialectAttributes(Operation *op,
ArrayRef<llvm::Instruction *> instructions);

/// Translates parameter attributes and adds them to the returned AttrBuilder.
/// Returns failure if any of the translations failed.
/// Translates parameter attributes of a function and adds them to the
/// returned AttrBuilder. Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder>
convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);

Expand Down
83 changes: 55 additions & 28 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,55 +1335,66 @@ void CallOp::print(OpAsmPrinter &p) {
getVarCalleeTypeAttrName(), getCConvAttrName(),
getOperandSegmentSizesAttrName(),
getOpBundleSizesAttrName(),
getOpBundleTagsAttrName()});
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
getResAttrsAttrName()});

p << " : ";
if (!isDirect)
p << getOperand(0).getType() << ", ";

// Reconstruct the function MLIR function type from operand and result types.
p.printFunctionalType(args.getTypes(), getResultTypes());
// Reconstruct the MLIR function type from operand and result types.
call_interface_impl::printFunctionSignature(
p, args.getTypes(), getArgAttrsAttr(),
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}

/// Parses the type of a call operation and resolves the operands if the parsing
/// succeeds. Returns failure otherwise.
static ParseResult parseCallTypeAndResolveOperands(
OpAsmParser &parser, OperationState &result, bool isDirect,
ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
ArrayRef<OpAsmParser::UnresolvedOperand> operands,
SmallVectorImpl<DictionaryAttr> &argAttrs,
SmallVectorImpl<DictionaryAttr> &resultAttrs) {
SMLoc trailingTypesLoc = parser.getCurrentLocation();
SmallVector<Type> types;
if (parser.parseColonTypeList(types))
if (parser.parseColon())
return failure();

if (isDirect && types.size() != 1)
return parser.emitError(trailingTypesLoc,
"expected direct call to have 1 trailing type");
if (!isDirect && types.size() != 2)
return parser.emitError(trailingTypesLoc,
"expected indirect call to have 2 trailing types");

auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
if (!funcType)
if (!isDirect) {
types.emplace_back();
if (parser.parseType(types.back()))
return failure();
if (parser.parseOptionalComma())
return parser.emitError(
trailingTypesLoc, "expected indirect call to have 2 trailing types");
}
SmallVector<Type> argTypes;
SmallVector<Type> resTypes;
if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
resTypes, resultAttrs)) {
if (isDirect)
return parser.emitError(trailingTypesLoc,
"expected direct call to have 1 trailing types");
return parser.emitError(trailingTypesLoc,
"expected trailing function type");
if (funcType.getNumResults() > 1)
}

if (resTypes.size() > 1)
return parser.emitError(trailingTypesLoc,
"expected function with 0 or 1 result");
if (funcType.getNumResults() == 1 &&
llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
return parser.emitError(trailingTypesLoc,
"expected a non-void result type");

// The head element of the types list matches the callee type for
// indirect calls, while the types list is emtpy for direct calls.
// Append the function input types to resolve the call operation
// operands.
llvm::append_range(types, funcType.getInputs());
llvm::append_range(types, argTypes);
if (parser.resolveOperands(operands, types, parser.getNameLoc(),
result.operands))
return failure();
if (funcType.getNumResults() != 0)
result.addTypes(funcType.getResults());
if (resTypes.size() != 0)
result.addTypes(resTypes);

return success();
}
Expand Down Expand Up @@ -1497,8 +1508,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();

// Parse the trailing type list and resolve the operands.
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
SmallVector<DictionaryAttr> argAttrs;
SmallVector<DictionaryAttr> resultAttrs;
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
argAttrs, resultAttrs))
return failure();
call_interface_impl::addArgAndResultAttrs(
parser.getBuilder(), result, argAttrs, resultAttrs,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
opBundleOperandTypes,
getOpBundleSizesAttrName(result.name)))
Expand Down Expand Up @@ -1643,14 +1660,16 @@ void InvokeOp::print(OpAsmPrinter &p) {
{getCalleeAttrName(), getOperandSegmentSizeAttr(),
getCConvAttrName(), getVarCalleeTypeAttrName(),
getOpBundleSizesAttrName(),
getOpBundleTagsAttrName()});
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
getResAttrsAttrName()});

p << " : ";
if (!isDirect)
p << getOperand(0).getType() << ", ";
p.printFunctionalType(
llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1),
getResultTypes());
call_interface_impl::printFunctionSignature(
p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
getArgAttrsAttr(),
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}

// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
Expand All @@ -1659,7 +1678,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
// ( `vararg(` var-callee-type `)` )?
// ( `[` op-bundles-list `]` )?
// attribute-dict? `:` (type `,`)? function-type
// attribute-dict? `:` (type `,`)?
// function-type-with-argument-attributes
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
SymbolRefAttr funcAttr;
Expand Down Expand Up @@ -1721,8 +1741,15 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();

// Parse the trailing type list and resolve the function operands.
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
SmallVector<DictionaryAttr> argAttrs;
SmallVector<DictionaryAttr> resultAttrs;
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
argAttrs, resultAttrs))
return failure();
call_interface_impl::addArgAndResultAttrs(
parser.getBuilder(), result, argAttrs, resultAttrs,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
opBundleOperandTypes,
getOpBundleSizesAttrName(result.name)))
Expand Down
39 changes: 39 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,39 @@ static void convertLinkerOptionsOp(ArrayAttr options,
linkerMDNode->addOperand(listMDNode);
}

static LogicalResult
convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
LLVM::ModuleTranslation &moduleTranslation) {
if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
!argAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(callOp, argAttrs);
if (failed(attrBuilder))
return failure();
call->addParamAttrs(argIdx, *attrBuilder);
}
}
}

ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
if (resAttrsArray && resAttrsArray.size() > 0) {
if (resAttrsArray.size() != 1)
return mlir::emitError(callOp.getLoc(),
"llvm.func cannot have multiple results");
if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
!resAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(callOp, resAttrs);
if (failed(attrBuilder))
return failure();
call->addRetAttrs(*attrBuilder);
}
}
return success();
}

static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
Expand Down Expand Up @@ -265,6 +298,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (callOp.getWillReturnAttr())
call->addFnAttr(llvm::Attribute::WillReturn);

if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
return failure();

if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
llvm::MemoryEffects memEffects =
llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
Expand Down Expand Up @@ -372,6 +408,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front(), opBundles);
}
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
if (failed(
convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
return failure();
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
if (invOp->getNumResults() != 0) {
Expand Down
36 changes: 36 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
if (failed(convertCallAttributes(callInst, callOp)))
return failure();
// Handle parameter and result attributes.
convertParameterAttributes(callInst, callOp, builder);
return callOp.getOperation();
}();

Expand Down Expand Up @@ -1786,6 +1788,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
return failure();

// Handle parameter and result attributes.
convertParameterAttributes(invokeInst, invokeOp, builder);

if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
else
Expand Down Expand Up @@ -2149,6 +2154,37 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
}

void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
CallOpInterface callOp,
OpBuilder &builder) {
llvm::AttributeList llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
bool anyArgAttrs = false;
for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
if (llvmArgAttrsSet.back().hasAttributes())
anyArgAttrs = true;
}
auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
SmallVector<Attribute> attrs;
for (auto &dict : dictAttrs)
attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
return builder.getArrayAttr(attrs);
};
if (anyArgAttrs) {
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
}

llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
if (!llvmResAttr.hasAttributes())
return;
DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
callOp.setResAttrsAttr(getArrayAttr({resAttrs}));
}

template <typename Op>
static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));
Expand Down
Loading