[RFC][mlir] Conditional support for fast-math attributes. #125620
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-flang-fir-hlfir
Author: Slava Zakharin (vzakhari)
Changes
This patch suggests changes for operations that support the fast-math interfaces (ArithFastMathInterface in the Arith dialect and FastmathFlagsInterface in the LLVM dialect). It is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types
The changes add new isArithFastMathApplicable/isFastmathApplicable methods to these interfaces.
The LLVM dialect isFastmathApplicable implementation is based on llvm-project/llvm/include/llvm/IR/Operator.h, line 380 in bac62ee.
The Arith dialect isArithFastMathApplicable is more relaxed, because it has to support custom MLIR types. This is the area where improvements are needed (see TODO comments). I will appreciate feedback here.
The HLFIR dialect is another example where conditional fast-math support may currently be applied.
Patch is 32.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125620.diff
17 Files Affected:
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8dbc9df9f553de..497d099fbe9366 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
llvm::cast<mlir::SymbolRefAttr>(callee));
setOperand(0, llvm::cast<mlir::Value>(callee));
}
+
+ /// Always allow FastMathFlags for fir.call's.
+ /// It is required to be able to propagate the call site's
+ /// FastMathFlags to the operations resulting from inlining
+ /// (if any) of a fir.call (see SimplifyIntrinsics pass).
+ /// We could analyze the arguments' data types to see if there are
+ /// any floating point types, but this is unreliable. For example,
+ /// the runtime calls mostly take !fir.box<none> arguments,
+ /// and tracking them to the definitions may be not easy.
+ /// TODO: this should be restricted to fir.runtime calls,
+ /// because FastMathFlags for the user calls must come
+ /// from the function body, not the call site.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
}];
}
@@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
}
static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);
+
+ /// Always allow FastMathFlags on fir.cmpc.
+ /// It does not produce a floating point result, but
+ /// LLVM is currently relying on fast-math flags attached
+ /// to floating point comparison.
+ /// This can be removed whenever LLVM stops doing it.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
}];
}
@@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
static bool isPointerCompatible(mlir::Type ty);
static bool canBeConverted(mlir::Type inType, mlir::Type outType);
static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);
+
+ // FIXME: fir.convert should support ArithFastMathInterface.
}];
let hasCanonicalizer = 1;
}
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 15296aa7e8c75c..0e6d536d9bde5d 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
/// Scalar integer or a sequence of integers (via boxed array or expr).
bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
+/// Return true iff FastMathFlagsAttr is applicable
+/// to the given HLFIR dialect operation that supports
+/// ArithFastMathInterface.
+bool isArithFastMathApplicable(mlir::Operation *op);
+
} // namespace hlfir
#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index f4102538efc3c2..f90ef8ed019ceb 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
@@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
@@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
@@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
@@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_SetLengthOp : hlfir_Op<"set_length",
@@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_DotProductOp : hlfir_Op<"dot_product",
@@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MatmulOp : hlfir_Op<"matmul",
@@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
let hasCanonicalizeMethod = 1;
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_TransposeOp : hlfir_Op<"transpose",
@@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_CShiftOp
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index d9779c46ae79e7..d749fc9c633d7c 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
- if (fmi) {
- // TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
- // For now set the attribute by the name.
+ if (fmi && fmi.isArithFastMathApplicable()) {
llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
if (fastMathFlags != mlir::arith::FastMathFlags::none)
op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a4959..fca3fb077d0a3f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
- rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
- call, resultTys, adaptor.getOperands(),
+ auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
+ call.getLoc(), resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
+ auto fmi =
+ mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
+ if (!fmi.isFastmathApplicable())
+ llvmCall->setAttr(
+ mlir::LLVM::CallOp::getFastmathAttrName(),
+ mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
+ mlir::LLVM::FastmathFlags::none));
+ rewriter.replaceOp(call, llvmCall);
return mlir::success();
}
};
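The CodeGen.cpp hunk above stops replacing the fir.call in a single step: it first creates the llvm.call with the converted flags, then asks the new operation whether fast-math flags apply, and resets them to none if they do not. A minimal, self-contained sketch of that pattern is given here. LoweredCall, lowerCall, the flag bit values and the "result is floating point" rule are all hypothetical simplifications, not MLIR or Flang APIs.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for mlir::LLVM::FastmathFlags; bit values are illustrative.
enum class FastmathFlags : std::uint32_t {
  none = 0,
  contract = 1u << 0,
  reassoc = 1u << 1
};

// Hypothetical stand-in for a lowered llvm.call.
struct LoweredCall {
  bool resultIsFloatingPoint;
  FastmathFlags flags;

  // Simplified applicability rule (assumption): flags only make sense
  // when the call returns a floating point value.
  bool isFastmathApplicable() const { return resultIsFloatingPoint; }
};

// Create the call with the converted flags first, then drop the flags
// if they are not applicable to the newly created operation.
inline LoweredCall lowerCall(bool resultIsFloatingPoint,
                             FastmathFlags sourceFlags) {
  LoweredCall call{resultIsFloatingPoint, sourceFlags};
  if (!call.isFastmathApplicable())
    call.flags = FastmathFlags::none;
  // Invariant after lowering: non-applicable calls carry no flags.
  assert(call.isFastmathApplicable() || call.flags == FastmathFlags::none);
  return call;
}
```

Under this simplified rule, a runtime call that returns a pointer ends up without flags, which matches the updated tests in this patch where `{fastmathFlags = #llvm.fastmath<contract>}` disappears from the `_FortranAioBeginExternalListOutput` call.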
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index cb77aef74acd56..53637f2090f2ef 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
mlir::Type elementType = getFortranElementType(unwrappedType);
return mlir::isa<mlir::IntegerType>(elementType);
}
+
+bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
+ if (llvm::any_of(op->getResults(), [](mlir::Value v) {
+ mlir::Type elementType = getFortranElementType(v.getType());
+ return mlir::arith::ArithFastMathInterface::isCompatibleType(
+ elementType);
+ }))
+ return true;
+ if (llvm::any_of(op->getOperands(), [](mlir::Value v) {
+ mlir::Type elementType = getFortranElementType(v.getType());
+ return mlir::arith::ArithFastMathInterface::isCompatibleType(
+ elementType);
+ }))
+ return true;
+
+ return false;
+}
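The header comment for hlfir::isArithFastMathApplicable says it returns true iff FastMathFlagsAttr is applicable, and the body looks at the element types of both results and operands. The sketch below models that check without any MLIR dependency. ElementKind, Op and isFastMathApplicable are hypothetical names, and treating only real and complex elements as compatible is an assumption made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical element kinds; stands in for the type that
// getFortranElementType() would return for a value.
enum class ElementKind { Integer, Logical, Real, Complex };

// Assumption: real and complex elements count as fast-math compatible.
inline bool isCompatibleElement(ElementKind kind) {
  return kind == ElementKind::Real || kind == ElementKind::Complex;
}

// Hypothetical operation: only the element kinds of its values matter here.
struct Op {
  std::vector<ElementKind> results;
  std::vector<ElementKind> operands;
};

// Applicable iff any result or any operand has a compatible element type.
inline bool isFastMathApplicable(const Op &op) {
  const bool anyResult = std::any_of(op.results.begin(), op.results.end(),
                                     isCompatibleElement);
  const bool anyOperand = std::any_of(op.operands.begin(), op.operands.end(),
                                      isCompatibleElement);
  return anyResult || anyOperand;
}

// Usage: a minloc-like reduction returns integer indices but reads real
// data, so only the operand check makes it applicable.
inline void usageExample() {
  Op minlocLike{{ElementKind::Integer}, {ElementKind::Real}};
  assert(isFastMathApplicable(minlocLike));
  Op integerSum{{ElementKind::Integer}, {ElementKind::Integer}};
  assert(!isFastMathApplicable(integerSum));
}
```

Checking operands as well as results is what lets operations such as hlfir.minloc and hlfir.maxloc keep their flags even though their results are integers.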
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index 0827e378c7c07e..b04188d3ee1d9c 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
%45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
gpu.launch_func @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
%46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
- %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+ %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
%48 = llvm.mlir.constant(9 : i32) : i32
%49 = llvm.mlir.zero : !llvm.ptr
%50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 401ebbc8c49fe6..c2c9ad362370f6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -136,7 +136,7 @@ module {
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
// CHECK: %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
-// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
// CHECK: "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
@@ -188,8 +188,8 @@ module {
// CHECK: %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
-// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
-// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
+// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
+// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ea9b0f6509b80b..bd23890556ffdd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
The destination type must to be strictly wider than the source type.
When operating on vectors, casts elementwise.
}];
+ let extraClassDeclaration = [{
+ bool isApplicable() { return true; }
+ }];
let hasVerifier = 1;
let hasFolder = 1;
@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
let hasCanonicalizer = 1;
let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($lhs)}];
+
+ let extraClassDeclaration = [{
+ /// Always allow FastMathFlags on arith.cmpf.
+ /// It does not produce a floating point result, but
+ /// LLVM is currently relying on fast-math flags attached
+ /// to floating point comparison.
+ /// This can be removed whenever LLVM stops doing it.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
+ }];
}
//===----------------------------------------------------------------------===//
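The arith.cmpf change above (and the matching fir.cmpc one) relies on a "default rule plus per-operation override" arrangement: the interface supplies a default answer based on result types, and an operation that needs flags despite a non-floating-point result overrides it. The following sketch shows that shape only. FastMathOp, AddFLike, CmpFLike and AddILike are hypothetical, and virtual dispatch is used purely for illustration, since MLIR op interfaces are not implemented this way.

```cpp
#include <cassert>

// Hypothetical interface: a default answer based on the result type,
// which individual operations may override.
struct FastMathOp {
  virtual ~FastMathOp() = default;
  virtual bool hasFloatingPointResult() const = 0;
  // Default rule: flags apply only to ops producing floating point values.
  virtual bool isFastMathApplicable() const {
    return hasFloatingPointResult();
  }
};

// Ordinary floating point op: the default rule is enough.
struct AddFLike : FastMathOp {
  bool hasFloatingPointResult() const override { return true; }
};

// Comparison: yields a boolean, but opts in explicitly.
struct CmpFLike : FastMathOp {
  bool hasFloatingPointResult() const override { return false; }
  bool isFastMathApplicable() const override { return true; }
};

// Integer-only op (hypothetical): falls back to the default and is rejected.
struct AddILike : FastMathOp {
  bool hasFloatingPointResult() const override { return false; }
};

inline void usageExample() {
  CmpFLike cmp;
  AddILike addi;
  const FastMathOp &viaInterfaceCmp = cmp;
  const FastMathOp &viaInterfaceAddI = addi;
  assert(viaInterfaceCmp.isFastMathApplicable());
  assert(!viaInterfaceAddI.isFastMathApplicable());
}
```

The override keeps the exception local to the comparison operation, so it can be deleted on its own once LLVM no longer relies on fast-math flags attached to floating point comparisons, as the comment in the patch anticipates.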
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da..860c096ef2e8b9 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
let cppNamespace = "::mlir::arith";
- let methods = [
- InterfaceMethod<
- /*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation",
- /*returnType=*/ "FastMathFlagsAttr",
- /*methodName=*/ "getFastMathFlagsAttr",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ let methods =
+ [InterfaceMethod<
+ /*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
+ /*returnType=*/"FastMathFlagsAttr",
+ /*methodName=*/"getFastMathFlagsAttr",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathAttr();
- }]
- >,
- StaticInterfaceMethod<
- /*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute
+ }]>,
+ StaticInterfaceMethod<
+ /*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
for the operation}],
- /*returnType=*/ "StringRef",
- /*methodName=*/ "getFastMathAttrName",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ /*returnType=*/"StringRef",
+ /*methodName=*/"getFastMathAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
return "fastmath";
- }]
- >
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{Returns true iff FastMathFlagsAttr attribute
+ is applicable to the operation that supports
+ ArithFastMathInterface. If it returns false,
+ then the FastMathFlagsAttr of the operation
+ must be nullptr or have 'none' value}],
+ /*returnType=*/"bool",
+ /*methodName=*/"isArithFastMathApplicable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
+ }]>];
- ];
+ let extraClassDeclaration = [{
+ /// Returns true iff the given type is a floating point type
+ /// or contains one.
+ static bool isCompatibleType(::mlir::Type);
+
+ /// Default implementation of isArithFastMathApplicable().
+ /// It returns true iff any of the results of the operations
+ /// has a type that is compatible with fast-math.
+ bool isApplicableImpl();
+ }];
+
+ let verify = [{
+ auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
+ auto attr = fmi.getFastMathFlagsAttr();
+ if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
+ !fmi.isArithFastMathApplicable())
+ return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+ return ::mlir::success();
+ }];
}
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
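The interface changes above combine three pieces: a type predicate ("is floating point or contains one"), a default applicability rule over result types, and a verifier that rejects non-none flags on operations where they are not applicable. A rough standalone model of how they fit together is sketched below. The Type struct, the FastMathFlags enum, and the use of a nullable pointer for the attribute are hypothetical simplifications and do not reproduce the MLIR API.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified type model: a scalar or an aggregate of types.
struct Type {
  enum class Kind { Integer, Float, Aggregate };
  Kind kind;
  std::vector<Type> elements; // only used when kind == Aggregate
};

// Hypothetical stand-in for arith::FastMathFlags.
enum class FastMathFlags { none, contract };

// True iff the type is floating point or contains a floating point type.
inline bool isCompatibleType(const Type &type) {
  assert(type.kind == Type::Kind::Aggregate || type.elements.empty());
  if (type.kind == Type::Kind::Float)
    return true;
  for (const Type &element : type.elements)
    if (isCompatibleType(element))
      return true;
  return false;
}

// Default rule: applicable iff any result type is compatible.
inline bool isApplicableImpl(const std::vector<Type> &resultTypes) {
  for (const Type &type : resultTypes)
    if (isCompatibleType(type))
      return true;
  return false;
}

// Verifier rule: the attribute may be absent (nullptr) or 'none' anywhere;
// any other value requires the operation to be applicable.
inline bool verify(const FastMathFlags *attr, bool applicable) {
  if (attr && *attr != FastMathFlags::none && !applicable)
    return false; // "FastMathFlagsAttr is not applicable"
  return true;
}
```

With this model, a `contract` flag on an operation whose only result is an integer fails verification, while the same operation with the flag absent or set to `none` passes.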
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 5ccddef158d9c2..ca55f933e4efad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -22,30 +22,60 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
let cppNamespace = "::mlir::LLVM";
- let methods = [
- InterfaceMethod<
- /*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation",
- /*returnType=*/ "::mlir::LLVM::FastmathFlagsAttr",
- /*methodName=*/ "getFastmathAttr",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ let methods =
+ [InterfaceMethod<
+ /*desc=*/"Returns a FastmathFlagsAttr attribute for the operation",
+ /*returnType=*/"::mlir::LLVM::FastmathFlagsAttr",
+ /*methodName=*/"getFastmathAttr",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathFlagsAttr();
- }]
- >,
- StaticInterfaceMethod<
- /*desc=*/ [{Returns the name of the FastmathFlagsAttr attribute
+ }]>,
+ StaticInterfaceMethod<
+ /*desc=*/[{Returns the name of the FastmathFlagsAttr attribute
for the operation}],
- /*returnType=*/ "::llvm::StringRef",
- /*methodName=*/ "getFastmathAttrName",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ /*returnType=*/"::llvm::StringRef",
+ /*methodName=*/"getFastmathAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
return "fastmathFlags";
- }]
- >
- ];
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{Returns true iff FastmathFlagsAttr attribute
+ is applicable to the operation that supports
+ FastmathInterface. If it returns false,
+ then the FastmathFlagsAttr of the operation
+ must be nullptr or have 'none' value}],
+ /*returnType=*/"bool",
+ /*methodName=*/"isFastmathApplicable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ return ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>(this->getOperation()).isApplicableImpl();
+ }]>];
+
+ let extraClassDeclaration = [{
+ /// Returns true iff the given type is a floating point typ...
[truncated]
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ea9b0f6509b80b..bd23890556ffdd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
The destination type must to be strictly wider than the source type.
When operating on vectors, casts elementwise.
}];
+ let extraClassDeclaration = [{
+ bool isApplicable() { return true; }
+ }];
let hasVerifier = 1;
let hasFolder = 1;
@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
let hasCanonicalizer = 1;
let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($lhs)}];
+
+ let extraClassDeclaration = [{
+ /// Always allow FastMathFlags on arith.cmpf.
+ /// It does not produce a floating point result, but
+ /// LLVM is currently relying on fast-math flags attached
+ /// to floating point comparison.
+ /// This can be removed whenever LLVM stops doing it.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da..860c096ef2e8b9 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
let cppNamespace = "::mlir::arith";
- let methods = [
- InterfaceMethod<
- /*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation",
- /*returnType=*/ "FastMathFlagsAttr",
- /*methodName=*/ "getFastMathFlagsAttr",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ let methods =
+ [InterfaceMethod<
+ /*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
+ /*returnType=*/"FastMathFlagsAttr",
+ /*methodName=*/"getFastMathFlagsAttr",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathAttr();
- }]
- >,
- StaticInterfaceMethod<
- /*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute
+ }]>,
+ StaticInterfaceMethod<
+ /*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
for the operation}],
- /*returnType=*/ "StringRef",
- /*methodName=*/ "getFastMathAttrName",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ /*returnType=*/"StringRef",
+ /*methodName=*/"getFastMathAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
return "fastmath";
- }]
- >
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{Returns true iff FastMathFlagsAttr attribute
+ is applicable to the operation that supports
+ ArithFastMathInterface. If it returns false,
+ then the FastMathFlagsAttr of the operation
+ must be nullptr or have 'none' value}],
+ /*returnType=*/"bool",
+ /*methodName=*/"isArithFastMathApplicable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
+ }]>];
- ];
+ let extraClassDeclaration = [{
+ /// Returns true iff the given type is a floating point type
+ /// or contains one.
+ static bool isCompatibleType(::mlir::Type);
+
+ /// Default implementation of isArithFastMathApplicable().
+ /// It returns true iff any of the results of the operations
+ /// has a type that is compatible with fast-math.
+ bool isApplicableImpl();
+ }];
+
+ let verify = [{
+ auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
+ auto attr = fmi.getFastMathFlagsAttr();
+ if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
+ !fmi.isArithFastMathApplicable())
+ return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+ return ::mlir::success();
+ }];
}
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 5ccddef158d9c2..ca55f933e4efad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -22,30 +22,60 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
let cppNamespace = "::mlir::LLVM";
- let methods = [
- InterfaceMethod<
- /*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation",
- /*returnType=*/ "::mlir::LLVM::FastmathFlagsAttr",
- /*methodName=*/ "getFastmathAttr",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ let methods =
+ [InterfaceMethod<
+ /*desc=*/"Returns a FastmathFlagsAttr attribute for the operation",
+ /*returnType=*/"::mlir::LLVM::FastmathFlagsAttr",
+ /*methodName=*/"getFastmathAttr",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathFlagsAttr();
- }]
- >,
- StaticInterfaceMethod<
- /*desc=*/ [{Returns the name of the FastmathFlagsAttr attribute
+ }]>,
+ StaticInterfaceMethod<
+ /*desc=*/[{Returns the name of the FastmathFlagsAttr attribute
for the operation}],
- /*returnType=*/ "::llvm::StringRef",
- /*methodName=*/ "getFastmathAttrName",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ /*returnType=*/"::llvm::StringRef",
+ /*methodName=*/"getFastmathAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
return "fastmathFlags";
- }]
- >
- ];
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{Returns true iff FastmathFlagsAttr attribute
+ is applicable to the operation that supports
+ FastmathInterface. If it returns false,
+ then the FastmathFlagsAttr of the operation
+ must be nullptr or have 'none' value}],
+ /*returnType=*/"bool",
+ /*methodName=*/"isFastmathApplicable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ return ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>(this->getOperation()).isApplicableImpl();
+ }]>];
+
+ let extraClassDeclaration = [{
+ /// Returns true iff the given type is a floating point typ...
[truncated]
> and my goal to add fast-math support for `arith.select` operation
Why would we want to have fast math flags over `arith.select`? What optimizations / rewrites does this allow?
For example, it enables vectorization of loops with min/max reductions in LLVM. Flang is currently producing
In general, in LLVM any instruction that produces a floating point result may have fast-math flags. This includes FP PHIs and selects.
  }))
    return true;

  return true;
I will change this to `false` after fixing the lowering tests.
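To make the intended behavior concrete, here is a minimal standalone sketch of the result-or-operand rule with the final fallback switched to `false`. It is not the HLFIR code: `ElemKind`, `ToyOp`, `isCompatibleElem`, and `isFastMathApplicable` are hypothetical stand-ins, and the sketch assumes that only floating point element types count as compatible.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical stand-in for an element type; this is not mlir::Type.
enum class ElemKind { Integer, Logical, Float };

// Toy counterpart of isCompatibleType: only floating point elements qualify.
inline bool isCompatibleElem(ElemKind kind) { return kind == ElemKind::Float; }

// Hypothetical summary of an operation: element types of results and operands.
struct ToyOp {
  std::vector<ElemKind> results;
  std::vector<ElemKind> operands;
};

// Fast-math flags are applicable if any result or any operand has a
// floating point element type.
inline bool isFastMathApplicable(const ToyOp &op) {
  if (std::any_of(op.results.begin(), op.results.end(), isCompatibleElem))
    return true;
  if (std::any_of(op.operands.begin(), op.operands.end(), isCompatibleElem))
    return true;
  // Fallback once no floating point type is involved.
  return false;
}
```

Checking operands as well as results matters for reductions such as a minloc-like operation, which takes floating point input but produces an integer index.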
auto attr = fmi.getFastMathFlagsAttr();
if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
    !fmi.isArithFastMathApplicable())
  return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
Suggested change:
- return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+ return $_op->emitOpError() << "has flag " << stringify(attr.getValue()) << " but fast-math flags are not applicable (`isArithFastMathApplicable()` returns false)";
Thanks! Will apply.
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
The destination type must to be strictly wider than the source type.
When operating on vectors, casts elementwise.
}];
let extraClassDeclaration = [{
bool isApplicable() { return true; }
When is this method called?
Did you mean `isArithFastMathApplicable()` here? (If so, we're missing a test to cover this.)
This is a leftover. There is no need to override `isArithFastMathApplicable` for `arith.extf`, because it has a floating point result.
auto fmi =
    mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
if (!fmi.isFastmathApplicable())
  llvmCall->setAttr(
There should be a better accessor on `LLVM::CallOp` (this one is generic and quite expensive): ODS generates an accessor per attribute.
Right. Will fix.
then the FastMathFlagsAttr of the operation
must be nullptr or have 'none' value}],
/*returnType=*/"bool",
/*methodName=*/"isArithFastMathApplicable",
Should I think of this as a sort of "verifier" for fast-math flags for the given operation?
Its intention is to tell whether fast-math flags are applicable. It is used in the verifier code below, but it may also be used by the passes/builders that create new operations supporting `ArithFastMathInterface`, e.g. see its usage in the `FIRBuilder.cpp` file above.
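The two uses mentioned here, rejecting flags in the verifier and guarding attribute creation in a builder, can be sketched with a small standalone model. This is only an illustration: `ToyOp`, `FastMathFlags`, `verifyFastMath`, and `setCommonAttributes` below are hypothetical names rather than the MLIR or Flang classes, and `none` is assumed to stand for both a missing attribute and an explicit `none` value.

```cpp
#include <string>

// Hypothetical flag set; `none` also models a missing attribute.
enum class FastMathFlags { none, contract, fast };

// Hypothetical operation; this is not an MLIR class.
struct ToyOp {
  bool producesFloat = false;
  FastMathFlags fastmath = FastMathFlags::none;
  // Default rule assumed here: applicable iff the result is floating point.
  bool isFastMathApplicable() const { return producesFloat; }
};

// Verifier-style check: returns an empty string on success and an error
// message when non-none flags sit on an operation that cannot carry them.
inline std::string verifyFastMath(const ToyOp &op) {
  if (op.fastmath != FastMathFlags::none && !op.isFastMathApplicable())
    return "fast-math flags are not applicable";
  return "";
}

// Builder-style helper: attach the builder's flags only when they are
// applicable, so the operation never fails the check above.
inline void setCommonAttributes(ToyOp &op, FastMathFlags builderFlags) {
  if (op.isFastMathApplicable() && builderFlags != FastMathFlags::none)
    op.fastmath = builderFlags;
}
```

Sharing one predicate between both paths keeps builders from emitting operations that the verifier would later reject.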
let extraClassDeclaration = [{
/// Always allow FastmathFlags on llvm.fcmp.
/// It does not produce a floating point result, but
/// LLVM is currently relying on fast-math flags attached
Do you mean that LLVM will look into the compare instead of the select operation for fast-math? It is unclear to me why you are changing this method for the cmpi and not the select.
LLVM can look at both compare and select, depending on what it needs to do.
What I mean here is: LLVM's `fcmp` instruction supports fast-math flags, and the `llvm.fcmp` operation should also support them. The general rule for instructions/operations to support fast-math is that they produce a floating point result. Neither `fcmp` nor `llvm.fcmp` produces a floating point result, so they are exceptions to the general rule, and `isFastmathApplicable` should be overridden here.
There is no need to override `isFastmathApplicable` for `llvm.select`, because it is covered by the general rule.
Note that the comment is explicitly saying that this is a temporary solution while LLVM expects it.
LLVM code has the following TODO about `fcmp`:
// FIXME: To clean up and correct the semantics of fast-math-flags, FCmp
// should not be treated as a math op, but the other opcodes should.
// This would make things consistent with Select/PHI (FP value type
// determines whether they are math ops and, therefore, capable of
// having fast-math-flags).
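The general rule and the fcmp exception described in this reply can be sketched as a default method plus an override. The classes below (`ToyOp`, `ToySelect`, `ToyFCmp`, `ToyType`) are hypothetical stand-ins written for illustration, not MLIR or LLVM classes; only the method name `isFastmathApplicable` is taken from the patch, and the real interface method may well have a different signature.

```cpp
// Hypothetical stand-ins for illustration; these are not MLIR classes.
enum class ToyType { Integer, Float };

struct ToyOp {
  ToyType resultType;
  explicit ToyOp(ToyType type) : resultType(type) {}
  virtual ~ToyOp() = default;

  // General rule: fast-math flags apply only to operations that
  // produce a floating point result.
  virtual bool isFastmathApplicable() const {
    return resultType == ToyType::Float;
  }
};

// A select is covered by the general rule: the answer depends on the
// type of the selected value.
struct ToySelect : ToyOp {
  using ToyOp::ToyOp;
};

// A floating point compare yields a boolean (integer) result, so it needs
// an explicit override to keep accepting fast-math flags.
struct ToyFCmp : ToyOp {
  ToyFCmp() : ToyOp(ToyType::Integer) {}
  bool isFastmathApplicable() const override { return true; }
};
```

In this model, `ToySelect(ToyType::Float)` accepts flags and `ToySelect(ToyType::Integer)` does not, while `ToyFCmp` accepts them only because of its override.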
Can you link to some example that shows why this is necessary? I'd think that |
I was asking myself the same question. I guess if an |
This could be supported by a dedicated unary op. |
I am sorry, I do not have examples readily available for you. There is LLVM Floating Point Working Group (https://discourse.llvm.org/t/floating-point-working-group/76907/10) that discussed the need for fast-math flags on The initial addition of fast-math support for |
I think we could do something like this:
|
Can you please explain how this is better than having fast-math flags on the select itself? It seems that what you are proposing depends on whether the operands' Please also note that this patch does not add fast-math support to |
FYI, I addressed the review comments in 8834e36, but github does not show it here for some reason. I see |
Sorry, I was mostly thinking out loud and did not mean to derail this PR. I'm trying to understand the goal stated in the PR description:
I'd not expect
I'd think that the real implementation can have a matcher that checks if |
Thanks for the explanation!
I am not against the |
This patch suggests changes for operations that support arith::ArithFastMathInterface/LLVM::FastmathFlagsInterface. Some of the operations may have fast-math flags not equal to `none` only if they operate on floating point values. This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types and my goal to add fast-math support for the `arith.select` operation, which may produce results of any type.
The changes add new isArithFastMathApplicable/isFastmathApplicable methods to the above interfaces that tell whether an operation supporting the interface may have non-none fast-math flags.
The LLVM dialect isFastmathApplicable implementation is based on https://github.com/llvm/llvm-project/blob/bac62ee5b473e70981a6bd9759ec316315fca07d/llvm/include/llvm/IR/Operator.h#L380
The ARITH dialect isArithFastMathApplicable is more relaxed, because it has to support custom MLIR types. This is the area where improvements are needed (see TODO comments). I would appreciate feedback here.
The HLFIR dialect is another example where conditional fast-math support may be applied currently.
8834e36 to a517b83
return true;

// TODO: what about TupleType and custom dialect struct-like types?
// It seems that they worth an interface to get to the list of element types.
Any suggestions about this?
Do you have a use case of an operation with the fastmath interface that takes/returns such a type?
If these operations are not common, maybe it is best/cheaper for the rest of the usages to keep the logic here simple and have these operations do the type visit as needed, like you did in HLFIR.
That said, such a type interface would make some sense to me in general.
I was mostly thinking about arith.select, but it seems it will be a separate discussion and changes (if any).
If these operations are not common, maybe it is best/cheaper for the rest of the usages to keep the logic here simple and have these operations do the type visit as needed, like you did in HLFIR.
This might be an acceptable approach. I was thinking about making the other dialects' lives easier by handling it here, but I can postpone the decision until the need arises.
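To make the open question about tuple and struct-like types concrete, here is a standalone sketch of a recursive element-type visit. Everything in it is an assumption for illustration: `ToyType`, `makeType` and `toyIsCompatibleType` are hypothetical and do not correspond to MLIR types or to the patch's isCompatibleType, and the "any element is compatible" rule for tuples is just one possible choice that the thread leaves undecided.

```cpp
#include <memory>
#include <utility>
#include <vector>

// Hypothetical type model for illustration; these are not MLIR types.
struct ToyType {
  enum Kind { Integer, Float, Complex, Shaped, Tuple } kind;
  // Element types: one entry for Shaped, any number for Tuple, unused otherwise.
  std::vector<std::shared_ptr<ToyType>> elements;
};

// Small helper to build toy types.
inline std::shared_ptr<ToyType>
makeType(ToyType::Kind kind,
         std::vector<std::shared_ptr<ToyType>> elements = {}) {
  return std::make_shared<ToyType>(ToyType{kind, std::move(elements)});
}

// Returns true if fast-math flags could be meaningful for values of this type.
// Assumption: a tuple is compatible if any of its elements is compatible.
inline bool toyIsCompatibleType(const ToyType &type) {
  switch (type.kind) {
  case ToyType::Float:
  case ToyType::Complex:
    return true;
  case ToyType::Shaped:
    // Shaped types defer to their single element type.
    return !type.elements.empty() &&
           toyIsCompatibleType(*type.elements.front());
  case ToyType::Tuple:
    for (const auto &element : type.elements)
      if (toyIsCompatibleType(*element))
        return true;
    return false;
  case ToyType::Integer:
    return false;
  }
  return false;
}
```

A type interface exposing the list of element types, as the TODO suggests, would let a generic visit like this work for custom dialect types; without it, each dialect would do its own visit, as was done for HLFIR.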
/// TODO: the results often have the same type, and traversing
/// the same type again and again is not very efficient.
/// We can cache it here for the duration of the processing.
/// Other ideas?
Any suggestions about this?
Are you talking about the results of the same operation with multiple results, or the results of different operations?
If this is about the former, it seems to me that caching would be overkill given the average number of results per operation. If this is about the latter, it seems to me that maintaining a shared cache somewhere would not be cheap. Is the call to isCompatibleType that expensive?
Yes, I was thinking about the results of the same operation, especially for the case of struct-like types that may have nested types... I agree with you that in the current state of isCompatibleType it does not make sense to do any caching. I will remove the comment.
Thanks @vzakhari, the flang dialect changes look good and this makes sense to me in general. Please wait for the approval from those who had comments.
return isCompatibleType(shapedType.getElementType());

// ComplexType's element type is always a FloatType.
if (auto complexType = dyn_cast<ComplexType>(type))
nit: move in isa with FloatType?
This feels very odd to me - like an implementation detail of LLVM leaking way higher up into the stack than it should. |
I am not sure I understand how fast-math attributes attached to
Isn't it true for any other discardable attribute? In my opinion, the transformations should drop any attributes that they don't know about or preserve any attribute that they know about (generic attribute interfaces may ease the handling here). Otherwise, how can we guarantee that a discardable attribute propagated during an op->op transformation, without paying attention to preserving the attribute semantics, is still valid in the MLIR after the transformation? I created https://discourse.llvm.org/t/rfc-arithfastmathinterface-support-for-arith-select/84508 for further discussion of the fast-math flags on I think the discussion here should be about conditional support of fast-math attributes on operations that support the fast-math interface. I believe the HLFIR case is a good example where such support makes sense.
I do not think it is always possible (see example in the above discourse RFC). |
Thank you for the comments, Jean. I will upload updated files shortly.