From 4655bd66ce1fe4ab4d608e6fd3ec61924b2e0720 Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Mon, 14 Oct 2024 23:30:57 +0800 Subject: [PATCH] [mlir][LLVMIR] Add operand bundle support for llvm.intr.assume --- .../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td | 1 + .../mlir/Dialect/LLVMIR/LLVMDialect.td | 2 + .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 44 +++++++-- .../include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 25 +++-- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 18 +--- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 2 +- .../include/mlir/Target/LLVMIR/ModuleImport.h | 2 + mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 96 ++++++++++++------- .../LLVMIR/LLVMIRToLLVMTranslation.cpp | 6 ++ .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 16 +++- .../Dialect/NVVM/LLVMIRToNVVMTranslation.cpp | 6 ++ mlir/lib/Target/LLVMIR/ModuleImport.cpp | 32 ++++++- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 37 ++++++- .../expand-then-convert-to-llvm.mlir | 2 +- .../MemRefToLLVM/memref-to-llvm.mlir | 4 +- mlir/test/Dialect/LLVMIR/inlining.mlir | 4 +- mlir/test/Dialect/LLVMIR/roundtrip.mlir | 27 ++++++ mlir/test/Target/LLVMIR/Import/intrinsic.ll | 12 ++- .../test/Target/LLVMIR/llvmir-intrinsics.mlir | 15 +++ mlir/test/Target/LLVMIR/llvmir-invalid.mlir | 2 +- 20 files changed, 276 insertions(+), 77 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td index 0e38325f9891ac..e81db32bcaad03 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td @@ -71,6 +71,7 @@ class ArmSME_IntrOp immArgPositions=*/immArgPositions, /*list immArgAttrNames=*/immArgAttrNames>; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td index edcc34461f2f26..63a1a7a9888c2a 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td @@ -58,6 +58,8 @@ def LLVM_Dialect : Dialect { static StringRef getStructRetAttrName() { return "llvm.sret"; } static StringRef getWriteOnlyAttrName() { return "llvm.writeonly"; } static StringRef getZExtAttrName() { return "llvm.zeroext"; } + static StringRef getOpBundleSizesAttrName() { return "op_bundle_sizes"; } + static StringRef getOpBundleTagsAttrName() { return "op_bundle_tags"; } // TODO Restrict the usage of this to parameter attributes once there is an // alternative way of modeling memory effects on FunctionOpInterface. /// Name of the attribute that will cause the creation of a readnone memory diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 448a171cf3e412..71a65866070473 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -120,7 +120,8 @@ def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">; def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">; def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"] + /*requiresOpBundles=*/0, /*immArgPositions=*/[1, 2, 3], + /*immArgAttrNames=*/["rw", "hint", "cache"] > { let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache); } @@ -173,7 +174,8 @@ class LLVM_MemcpyIntrOpBase : DeclareOpInterfaceMethods, DeclareOpInterfaceMethods], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> { + /*requiresOpBundles=*/0, /*immArgPositions=*/[3], + /*immArgAttrNames=*/["isVolatile"]> { dag args = (ins Arg:$dst, Arg:$src, AnySignlessInteger:$len, I1Attr:$isVolatile); @@ -203,7 +205,8 @@ def LLVM_MemcpyInlineOp : DeclareOpInterfaceMethods, DeclareOpInterfaceMethods], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> { + /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3], + /*immArgAttrNames=*/["len", "isVolatile"]> { dag args = (ins Arg:$dst, Arg:$src, APIntAttr:$len, I1Attr:$isVolatile); @@ -229,7 +232,8 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], DeclareOpInterfaceMethods, DeclareOpInterfaceMethods], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> { + /*requiresOpBundles=*/0, /*immArgPositions=*/[3], + /*immArgAttrNames=*/["isVolatile"]> { dag args = (ins Arg:$dst, I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile); // Append the alias attributes defined by LLVM_IntrOpBase. @@ -283,7 +287,8 @@ def LLVM_NoAliasScopeDeclOp class LLVM_LifetimeBaseOp : LLVM_ZeroResultIntrOp], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> { + /*requiresOpBundles=*/0, /*immArgPositions=*/[0], + /*immArgAttrNames=*/["size"]> { let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr); let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))"; } @@ -303,7 +308,8 @@ def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1], def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2], [DeclareOpInterfaceMethods], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*immArgPositions=*/[1], /*immArgAttrNames=*/["size"]> { + /*requiresOpBundles=*/0, /*immArgPositions=*/[1], + /*immArgAttrNames=*/["size"]> { let arguments = (ins LLVM_DefaultPointer:$start, I64Attr:$size, LLVM_AnyPointer:$ptr); @@ -365,7 +371,7 @@ class LLVM_ConstrainedIntr mlirOperands; SmallVector mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( - llvmOperands.take_front( }] # numArgs # [{), + llvmOperands.take_front( }] # numArgs # [{), {}, false, {}, {}, mlirOperands, mlirAttrs))) { return failure(); } @@ -426,7 +432,26 @@ def LLVM_USHLSat : LLVM_BinarySameArgsIntrOpI<"ushl.sat">; // def LLVM_AssumeOp - : LLVM_ZeroResultIntrOp<"assume", []>, Arguments<(ins I1:$cond)>; + : LLVM_ZeroResultIntrOp<"assume", /*overloadedOperands=*/[], /*traits=*/[], + /*requiresAccessGroup=*/0, + /*requiresAliasAnalysis=*/0, + /*requiresOpBundles=*/1> { + dag args = (ins I1:$cond); + let arguments = !con(args, opBundleArgs); + + let assemblyFormat = [{ + $cond + ( custom($op_bundle_operands, type($op_bundle_operands), + $op_bundle_tags)^ )? + `:` type($cond) attr-dict + }]; + + let builders = [ + OpBuilder<(ins "Value":$cond)> + ]; + + let hasVerifier = 1; +} def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0], [Pure, SameOperandsAndResultType]> { @@ -989,7 +1014,8 @@ def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">; def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap", /*overloadedOperands=*/[], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> { + /*requiresOpBundles=*/0, /*immArgPositions=*/[0], + /*immArgAttrNames=*/["failureKind"]> { let arguments = (ins I8Attr:$failureKind); } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index c3d352d8d0dd48..a38dafa4d9cf34 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -291,7 +291,7 @@ class LLVM_IntrOpBase overloadedResults, list overloadedOperands, list traits, int numResults, bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, - bit requiresFastmath = 0, + bit requiresFastmath = 0, bit requiresOpBundles = 0, list immArgPositions = [], list immArgAttrNames = []> : LLVM_OpBase:$noalias_scopes, OptionalAttr:$tbaa), (ins ))); + dag opBundleArgs = !if(!gt(requiresOpBundles, 0), + (ins VariadicOfVariadic:$op_bundle_operands, + DenseI32ArrayAttr:$op_bundle_sizes, + OptionalAttr:$op_bundle_tags), + (ins )); string llvmEnumName = enumName; string overloadedResultsCpp = "{" # !interleave(overloadedResults, ", ") # "}"; string overloadedOperandsCpp = "{" # !interleave(overloadedOperands, ", ") # "}"; @@ -336,6 +342,8 @@ class LLVM_IntrOpBase mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( llvmOperands, + llvmOpBundles, + }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{, }] # immArgPositionsCpp # [{, }] # immArgAttrNamesCpp # [{, mlirOperands, @@ -381,12 +389,14 @@ class LLVM_IntrOp overloadedResults, list overloadedOperands, list traits, int numResults, bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, bit requiresFastmath = 0, + bit requiresOpBundles = 0, list immArgPositions = [], list immArgAttrNames = []> : LLVM_IntrOpBase; + requiresFastmath, requiresOpBundles, immArgPositions, + immArgAttrNames>; // Base class for LLVM intrinsic operations returning no results. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". @@ -406,11 +416,13 @@ class LLVM_ZeroResultIntrOp overloadedOperands = [], list traits = [], bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, + bit requiresOpBundles = 0, list immArgPositions = [], list immArgAttrNames = []> : LLVM_IntrOp; + /*requiresFastMath=*/0, requiresOpBundles, immArgPositions, + immArgAttrNames>; // Base class for LLVM intrinsic operations returning one result. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is @@ -422,11 +434,12 @@ class LLVM_OneResultIntrOp overloadedResults = [], list overloadedOperands = [], list traits = [], bit requiresFastmath = 0, - list immArgPositions = [], - list immArgAttrNames = []> + list immArgPositions = [], + list immArgAttrNames = []> : LLVM_IntrOp; + requiresFastmath, /*requiresOpBundles=*/0, immArgPositions, + immArgAttrNames>; def LLVM_OneResultOpBuilder : OpBuilder<(ins "Type":$resultType, "ValueRange":$operands, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 000d92f9ea3bcb..d388de3960f2b2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -559,11 +559,7 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [ VariadicOfVariadic:$op_bundle_operands, DenseI32ArrayAttr:$op_bundle_sizes, - DefaultValuedProperty< - ArrayProperty, - "ArrayRef{}", - "SmallVector{}" - >:$op_bundle_tags); + OptionalAttr:$op_bundle_tags); let results = (outs Optional:$result); let successors = (successor AnySuccessor:$normalDest, AnySuccessor:$unwindDest); @@ -678,11 +674,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", VariadicOfVariadic:$op_bundle_operands, DenseI32ArrayAttr:$op_bundle_sizes, - DefaultValuedProperty< - ArrayProperty, - "ArrayRef{}", - "SmallVector{}" - >:$op_bundle_tags); + OptionalAttr:$op_bundle_tags); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. let arguments = !con(args, aliasAttrs); let results = (outs Optional:$result); @@ -1930,11 +1922,7 @@ def LLVM_CallIntrinsicOp VariadicOfVariadic:$op_bundle_operands, DenseI32ArrayAttr:$op_bundle_sizes, - DefaultValuedProperty< - ArrayProperty, - "ArrayRef{}", - "SmallVector{}" - >:$op_bundle_tags); + OptionalAttr:$op_bundle_tags); let results = (outs Optional:$results); let llvmBuilder = [{ return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation); diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index b80d9ae88910c4..d6591119bc6e18 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -98,7 +98,7 @@ class ROCDL_IntrOp overloadedResults, LLVM_IntrOpBase; + requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>; //===----------------------------------------------------------------------===// // ROCDL special register op definitions diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 436675793062eb..229efd8193555c 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -239,6 +239,8 @@ class ModuleImport { /// corresponding MLIR attribute names. LogicalResult convertIntrinsicArguments(ArrayRef values, + ArrayRef opBundles, + bool requiresOpBundles, ArrayRef immArgPositions, ArrayRef immArgAttrNames, SmallVectorImpl &valuesOut, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 006d412936a337..3a38065560940f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -241,13 +241,18 @@ static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands, static void printOpBundles(OpAsmPrinter &p, Operation *op, OperandRangeRange opBundleOperands, TypeRangeRange opBundleOperandTypes, - ArrayRef opBundleTags) { + std::optional opBundleTags) { + if (opBundleOperands.empty()) + return; + assert(opBundleTags && "expect operand bundle tags"); + p << "["; llvm::interleaveComma( - llvm::zip(opBundleOperands, opBundleOperandTypes, opBundleTags), p, + llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p, [&p](auto bundle) { + auto bundleTag = llvm::cast(std::get<2>(bundle)).getValue(); printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle), - std::get<2>(bundle)); + bundleTag); }); p << "]"; } @@ -256,7 +261,7 @@ static ParseResult parseOneOpBundle( OpAsmParser &p, SmallVector> &opBundleOperands, SmallVector> &opBundleOperandTypes, - SmallVector &opBundleTags) { + SmallVector &opBundleTags) { SMLoc currentParserLoc = p.getCurrentLocation(); SmallVector operands; SmallVector types; @@ -276,7 +281,7 @@ static ParseResult parseOneOpBundle( opBundleOperands.push_back(std::move(operands)); opBundleOperandTypes.push_back(std::move(types)); - opBundleTags.push_back(std::move(tag)); + opBundleTags.push_back(StringAttr::get(p.getContext(), tag)); return success(); } @@ -285,16 +290,17 @@ static std::optional parseOpBundles( OpAsmParser &p, SmallVector> &opBundleOperands, SmallVector> &opBundleOperandTypes, - SmallVector &opBundleTags) { + ArrayAttr &opBundleTags) { if (p.parseOptionalLSquare()) return std::nullopt; if (succeeded(p.parseOptionalRSquare())) return success(); + SmallVector opBundleTagAttrs; auto bundleParser = [&] { return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes, - opBundleTags); + opBundleTagAttrs); }; if (p.parseCommaSeparatedList(bundleParser)) return failure(); @@ -302,6 +308,8 @@ static std::optional parseOpBundles( if (p.parseRSquare()) return failure(); + opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs); + return success(); } @@ -1039,7 +1047,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, - /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt, + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1066,7 +1074,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, - /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt, + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1079,7 +1087,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, - /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt, + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1092,7 +1100,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, - /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt, + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1192,12 +1200,20 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) { template static LogicalResult verifyOperandBundles(OpType &op) { OperandRangeRange opBundleOperands = op.getOpBundleOperands(); - ArrayRef opBundleTags = op.getOpBundleTags(); + std::optional opBundleTags = op.getOpBundleTags(); - if (opBundleTags.size() != opBundleOperands.size()) + auto isStringAttr = [](Attribute tagAttr) { + return isa(tagAttr); + }; + if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr)) + return op.emitError("operand bundle tag must be a StringAttr"); + + size_t numOpBundles = opBundleOperands.size(); + size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0; + if (numOpBundles != numOpBundleTags) return op.emitError("expected ") - << opBundleOperands.size() - << " operand bundle tags, but actually got " << opBundleTags.size(); + << numOpBundles << " operand bundle tags, but actually got " + << numOpBundleTags; return success(); } @@ -1329,7 +1345,8 @@ void CallOp::print(OpAsmPrinter &p) { {getCalleeAttrName(), getTailCallKindAttrName(), getVarCalleeTypeAttrName(), getCConvAttrName(), getOperandSegmentSizesAttrName(), - getOpBundleSizesAttrName()}); + getOpBundleSizesAttrName(), + getOpBundleTagsAttrName()}); p << " : "; if (!isDirect) @@ -1437,7 +1454,7 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operands; SmallVector> opBundleOperands; SmallVector> opBundleOperandTypes; - SmallVector opBundleTags; + ArrayAttr opBundleTags; // Default to C Calling Convention if no keyword is provided. result.addAttribute( @@ -1483,9 +1500,9 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { parser, opBundleOperands, opBundleOperandTypes, opBundleTags); result && failed(*result)) return failure(); - if (!opBundleTags.empty()) - result.getOrAddProperties().op_bundle_tags = - std::move(opBundleTags); + if (opBundleTags && !opBundleTags.empty()) + result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(), + opBundleTags); if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1525,8 +1542,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, auto calleeType = func.getFunctionType(); build(builder, state, getCallOpResultTypes(calleeType), getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops, - normalOps, unwindOps, nullptr, nullptr, {}, std::nullopt, normal, - unwind); + normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, @@ -1535,7 +1551,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, ValueRange unwindOps) { build(builder, state, tys, /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr, - nullptr, {}, std::nullopt, normal, unwind); + nullptr, {}, {}, normal, unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, @@ -1544,7 +1560,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, Block *unwind, ValueRange unwindOps) { build(builder, state, getCallOpResultTypes(calleeType), getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps, - nullptr, nullptr, {}, std::nullopt, normal, unwind); + nullptr, nullptr, {}, {}, normal, unwind); } SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { @@ -1634,7 +1650,8 @@ void InvokeOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), {getCalleeAttrName(), getOperandSegmentSizeAttr(), getCConvAttrName(), getVarCalleeTypeAttrName(), - getOpBundleSizesAttrName()}); + getOpBundleSizesAttrName(), + getOpBundleTagsAttrName()}); p << " : "; if (!isDirect) @@ -1657,7 +1674,7 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { TypeAttr varCalleeType; SmallVector> opBundleOperands; SmallVector> opBundleOperandTypes; - SmallVector opBundleTags; + ArrayAttr opBundleTags; Block *normalDest, *unwindDest; SmallVector normalOperands, unwindOperands; Builder &builder = parser.getBuilder(); @@ -1703,9 +1720,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { parser, opBundleOperands, opBundleOperandTypes, opBundleTags); result && failed(*result)) return failure(); - if (!opBundleTags.empty()) - result.getOrAddProperties().op_bundle_tags = - std::move(opBundleTags); + if (opBundleTags && !opBundleTags.empty()) + result.addAttribute( + InvokeOp::getOpBundleTagsAttrName(result.name).getValue(), + opBundleTags); if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -3332,7 +3350,7 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, mlir::StringAttr intrin, mlir::ValueRange args) { build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args, FastmathFlagsAttr{}, - /*op_bundle_operands=*/{}); + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}); } void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, @@ -3340,14 +3358,14 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, mlir::LLVM::FastmathFlagsAttr fastMathFlags) { build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args, fastMathFlags, - /*op_bundle_operands=*/{}); + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}); } void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, mlir::Type resultType, mlir::StringAttr intrin, mlir::ValueRange args) { build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{}, - /*op_bundle_operands=*/{}); + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}); } void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, @@ -3355,7 +3373,7 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, mlir::StringAttr intrin, mlir::ValueRange args, mlir::LLVM::FastmathFlagsAttr fastMathFlags) { build(builder, state, resultTypes, intrin, args, fastMathFlags, - /*op_bundle_operands=*/{}); + /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}); } //===----------------------------------------------------------------------===// @@ -3412,6 +3430,18 @@ void InlineAsmOp::getEffects( } } +//===----------------------------------------------------------------------===// +// AssumeOp (intrinsic) +//===----------------------------------------------------------------------===// + +void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, + mlir::Value cond) { + return build(builder, state, cond, /*op_bundle_operands=*/{}, + /*op_bundle_tags=*/{}); +} + +LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); } + //===----------------------------------------------------------------------===// // masked_gather (intrinsic) //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp index d034e576dfc579..4fd043c7c93e68 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -68,6 +68,12 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder, if (isConvertibleIntrinsic(intrinsicID)) { SmallVector args(inst->args()); ArrayRef llvmOperands(args); + + SmallVector llvmOpBundles; + llvmOpBundles.reserve(inst->getNumOperandBundles()); + for (unsigned i = 0; i < inst->getNumOperandBundles(); ++i) + llvmOpBundles.push_back(inst->getOperandBundleAt(i)); + #include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc" } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index a8595d14ccf2e5..2084e527773ca8 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -114,17 +114,27 @@ convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag, } static SmallVector -convertOperandBundles(OperandRangeRange bundleOperands, - ArrayRef bundleTags, +convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags, LLVM::ModuleTranslation &moduleTranslation) { SmallVector bundles; bundles.reserve(bundleOperands.size()); - for (auto [operands, tag] : llvm::zip_equal(bundleOperands, bundleTags)) + for (auto [operands, tagAttr] : llvm::zip_equal(bundleOperands, bundleTags)) { + StringRef tag = cast(tagAttr).getValue(); bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation)); + } return bundles; } +static SmallVector +convertOperandBundles(OperandRangeRange bundleOperands, + std::optional bundleTags, + LLVM::ModuleTranslation &moduleTranslation) { + if (!bundleTags) + return {}; + return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation); +} + /// Builder for LLVM_CallIntrinsicOp static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp index bc830a77f3c580..2c0b665ad0d833 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp @@ -50,6 +50,12 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder, if (isConvertibleIntrinsic(intrinsicID)) { SmallVector args(inst->args()); ArrayRef llvmOperands(args); + + SmallVector llvmOpBundles; + llvmOpBundles.reserve(inst->getNumOperandBundles()); + for (unsigned i = 0; i < inst->getNumOperandBundles(); ++i) + llvmOpBundles.push_back(inst->getOperandBundleAt(i)); + #include "mlir/Dialect/LLVMIR/NVVMFromLLVMIRConversions.inc" } diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 4ff1f1135b0a88..7c3beaa850b917 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1291,7 +1291,8 @@ ModuleImport::convertValues(ArrayRef values) { } LogicalResult ModuleImport::convertIntrinsicArguments( - ArrayRef values, ArrayRef immArgPositions, + ArrayRef values, ArrayRef opBundles, + bool requiresOpBundles, ArrayRef immArgPositions, ArrayRef immArgAttrNames, SmallVectorImpl &valuesOut, SmallVectorImpl &attrsOut) { assert(immArgPositions.size() == immArgAttrNames.size() && @@ -1321,6 +1322,35 @@ LogicalResult ModuleImport::convertIntrinsicArguments( valuesOut.push_back(*mlirValue); } + SmallVector opBundleSizes; + SmallVector opBundleTagAttrs; + if (requiresOpBundles) { + opBundleSizes.reserve(opBundles.size()); + opBundleTagAttrs.reserve(opBundles.size()); + + for (const llvm::OperandBundleUse &bundle : opBundles) { + opBundleSizes.push_back(bundle.Inputs.size()); + opBundleTagAttrs.push_back(StringAttr::get(context, bundle.getTagName())); + + for (const llvm::Use &opBundleOperand : bundle.Inputs) { + auto operandMlirValue = convertValue(opBundleOperand.get()); + if (failed(operandMlirValue)) + return failure(); + valuesOut.push_back(*operandMlirValue); + } + } + + auto opBundleSizesAttr = DenseI32ArrayAttr::get(context, opBundleSizes); + auto opBundleSizesAttrNameAttr = + StringAttr::get(context, LLVMDialect::getOpBundleSizesAttrName()); + attrsOut.push_back({opBundleSizesAttrNameAttr, opBundleSizesAttr}); + + auto opBundleTagsAttr = ArrayAttr::get(context, opBundleTagAttrs); + auto opBundleTagsAttrNameAttr = + StringAttr::get(context, LLVMDialect::getOpBundleTagsAttrName()); + attrsOut.push_back({opBundleTagsAttrNameAttr, opBundleTagsAttr}); + } + return success(); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index add0a31c114f8d..22ae5a94e9e9a7 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -55,6 +55,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include #include #define DEBUG_TYPE "llvm-dialect-to-llvm-ir" @@ -854,8 +855,40 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal " "length"); + SmallVector opBundles; + size_t numOpBundleOperands = 0; + auto opBundleSizesAttr = cast_if_present( + intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName())); + auto opBundleTagsAttr = cast_if_present( + intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName())); + + if (opBundleSizesAttr && opBundleTagsAttr) { + ArrayRef opBundleSizes = opBundleSizesAttr.asArrayRef(); + assert(opBundleSizes.size() == opBundleTagsAttr.size() && + "operand bundles and tags do not match"); + + numOpBundleOperands = + std::reduce(opBundleSizes.begin(), opBundleSizes.end()); + assert(numOpBundleOperands <= intrOp->getNumOperands() && + "operand bundle operands is more than the number of operands"); + + ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands); + size_t nextOperandIdx = 0; + opBundles.reserve(opBundleSizesAttr.size()); + + for (auto [opBundleTagAttr, bundleSize] : + llvm::zip(opBundleTagsAttr, opBundleSizes)) { + auto bundleTag = cast(opBundleTagAttr).str(); + auto bundleOperands = moduleTranslation.lookupValues( + operands.slice(nextOperandIdx, bundleSize)); + opBundles.emplace_back(std::move(bundleTag), std::move(bundleOperands)); + nextOperandIdx += bundleSize; + } + } + // Map operands and attributes to LLVM values. - auto operands = moduleTranslation.lookupValues(intrOp->getOperands()); + auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands); + auto operands = moduleTranslation.lookupValues(opOperands); SmallVector args(immArgPositions.size() + operands.size()); for (auto [immArgPos, immArgName] : llvm::zip(immArgPositions, immArgAttrNames)) { @@ -890,7 +923,7 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration( module, intrinsic, overloadedTypes); - return builder.CreateCall(llvmIntr, args); + return builder.CreateCall(llvmIntr, args, opBundles); } /// Given a single MLIR operation, create the corresponding LLVM IR operation diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir index b86103422b0745..55b1bc9c545a85 100644 --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -684,7 +684,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf // CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64 // CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}} : i64 // CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64 -// CHECK: "llvm.intr.assume"(%[[CMP]]) : (i1) -> () +// CHECK: llvm.intr.assume %[[CMP]] : i1 // CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32 // CHECK: return %[[VAL]] : f32 diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 9dc22abf143bf0..48dc9079333d4f 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -160,7 +160,7 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) { // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm.ptr to i64 // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64 // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64 - // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> () + // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1 memref.assume_alignment %0, 16 : memref<4x4xf16> return } @@ -177,7 +177,7 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64 // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64 // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64 - // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> () + // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1 memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>> return } diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir index f9551e311df59f..0b7ca3f2bb048a 100644 --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -18,7 +18,7 @@ func.func @inner_func_inlinable(%ptr : !llvm.ptr) -> i32 { "llvm.intr.memset"(%ptr, %byte, %0) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> () "llvm.intr.memmove"(%ptr, %ptr, %0) <{isVolatile = true}> : (!llvm.ptr, !llvm.ptr, i32) -> () "llvm.intr.memcpy"(%ptr, %ptr, %0) <{isVolatile = true}> : (!llvm.ptr, !llvm.ptr, i32) -> () - "llvm.intr.assume"(%true) : (i1) -> () + llvm.intr.assume %true : i1 llvm.fence release %2 = llvm.atomicrmw add %ptr, %0 monotonic : !llvm.ptr, i32 %3 = llvm.cmpxchg %ptr, %0, %1 acq_rel monotonic : !llvm.ptr, i32 @@ -44,7 +44,7 @@ func.func @inner_func_inlinable(%ptr : !llvm.ptr) -> i32 { // CHECK: "llvm.intr.memset"(%[[PTR]] // CHECK: "llvm.intr.memmove"(%[[PTR]], %[[PTR]] // CHECK: "llvm.intr.memcpy"(%[[PTR]], %[[PTR]] -// CHECK: "llvm.intr.assume" +// CHECK: llvm.intr.assume // CHECK: llvm.fence release // CHECK: llvm.atomicrmw add %[[PTR]], %[[CST]] monotonic // CHECK: llvm.cmpxchg %[[PTR]], %[[CST]], %[[RES]] acq_rel monotonic diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 3062cdc38c0abb..b8ce7db795a1d1 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -836,3 +836,30 @@ llvm.func @test_call_intrin_with_opbundle(%arg0 : !llvm.ptr) { llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> () llvm.return } + +// CHECK-LABEL: @test_assume_intr_no_opbundle +llvm.func @test_assume_intr_no_opbundle(%arg0 : !llvm.ptr) { + %0 = llvm.mlir.constant(1 : i1) : i1 + // CHECK: llvm.intr.assume %0 : i1 + llvm.intr.assume %0 : i1 + llvm.return +} + +// CHECK-LABEL: @test_assume_intr_empty_opbundle +llvm.func @test_assume_intr_empty_opbundle(%arg0 : !llvm.ptr) { + %0 = llvm.mlir.constant(1 : i1) : i1 + // CHECK: llvm.intr.assume %0 : i1 + llvm.intr.assume %0 [] : i1 + llvm.return +} + +// CHECK-LABEL: @test_assume_intr_with_opbundles +llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) { + %0 = llvm.mlir.constant(1 : i1) : i1 + %1 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.mlir.constant(3 : i32) : i32 + %3 = llvm.mlir.constant(4 : i32) : i32 + // CHECK: llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1 + llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index 2fc2c3c6c32ffa..abb59ba8637d9f 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -612,11 +612,21 @@ define void @va_intrinsics_test(ptr %0, ptr %1, ...) { ; CHECK-LABEL: @assume ; CHECK-SAME: %[[TRUE:[a-zA-Z0-9]+]] define void @assume(i1 %true) { - ; CHECK: "llvm.intr.assume"(%[[TRUE]]) : (i1) -> () + ; CHECK: llvm.intr.assume %[[TRUE]] : i1 call void @llvm.assume(i1 %true) ret void } +; CHECK-LABEL: @assume_with_opbundles +; CHECK-SAME: %[[TRUE:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]] +define void @assume_with_opbundles(i1 %true, ptr %p) { + ; CHECK: %[[ALIGN:.+]] = llvm.mlir.constant(8 : i32) : i32 + ; CHECK: llvm.intr.assume %[[TRUE]] ["align"(%[[PTR]], %[[ALIGN]] : !llvm.ptr, i32)] : i1 + call void @llvm.assume(i1 %true) ["align"(ptr %p, i32 8)] + ret void +} + ; CHECK-LABEL: @is_constant ; CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]] define void @is_constant(i32 %0) { diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index de0dc8d21584fe..1ee4988fb55185 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -344,6 +344,21 @@ llvm.func @umin_test(%arg0: i32, %arg1: i32, %arg2: vector<8xi32>, %arg3: vector llvm.return } +// CHECK-LABEL: @assume_without_opbundles +llvm.func @assume_without_opbundles(%cond: i1) { + // CHECK: call void @llvm.assume(i1 %{{.+}}) + llvm.intr.assume %cond : i1 + llvm.return +} + +// CHECK-LABEL: @assume_with_opbundles +llvm.func @assume_with_opbundles(%cond: i1, %p: !llvm.ptr) { + %0 = llvm.mlir.constant(8 : i32) : i32 + // CHECK: call void @llvm.assume(i1 %{{.+}}) [ "align"(ptr %{{.+}}, i32 8) ] + llvm.intr.assume %cond ["align"(%p, %0 : !llvm.ptr, i32)] : i1 + llvm.return +} + // CHECK-LABEL: @vector_reductions llvm.func @vector_reductions(%arg0: f32, %arg1: vector<8xf32>, %arg2: vector<8xi32>) { // CHECK: call i32 @llvm.vector.reduce.add.v8i32 diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir index af0981440a1776..15658ea6068121 100644 --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -188,7 +188,7 @@ llvm.func @sadd_overflow_intr_wrong_type(%arg0 : i32, %arg1 : f32) -> !llvm.stru llvm.func @assume_intr_wrong_type(%cond : i16) { // expected-error @below{{op operand #0 must be 1-bit signless integer, but got 'i16'}} - "llvm.intr.assume"(%cond) : (i16) -> () + llvm.intr.assume %cond : i16 llvm.return }