[RFC][mlir] Conditional support for fast-math attributes. #125620

Open
wants to merge 4 commits into
base: main
vzakhari
Contributor

@vzakhari vzakhari commented Feb 4, 2025

This patch proposes changes for operations that implement
arith::ArithFastMathInterface or LLVM::FastmathFlagsInterface.
Some of these operations may carry fast-math flags other than none
only if they operate on floating-point values.

This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types
and by my goal to add fast-math support to the arith.select operation,
which may produce results of any type.

The changes add new isArithFastMathApplicable/isFastmathApplicable
methods to the above interfaces that tell whether an operation
supporting the interface may have non-none fast-math flags.
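
To make the intended contract concrete, here is a minimal standalone sketch of the rule: an operation that is not "applicable" may carry either no fast-math attribute or the none value. It does not use the MLIR API. `ToyOp`, `TypeKind`, `isFastMathApplicable` and `verifyFastMath` are hypothetical names invented for illustration, and the real default implementation in the patch may inspect types differently.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for MLIR concepts; none of this is the real MLIR API.
enum class TypeKind { Integer, Index, Float };

enum class FastMathFlags : std::uint32_t {
  none = 0,
  contract = 1u << 0,
  nnan = 1u << 1
};

struct ToyOp {
  const char *name;
  TypeKind resultType;    // type of the single result
  bool hasFastMathAttr;   // false models a missing (null) attribute
  FastMathFlags flags;    // only meaningful when hasFastMathAttr is true
  bool alwaysApplicable;  // models per-op overrides that always return true
};

// Default rule in this sketch: flags are applicable only when the
// result is a floating-point value, unless the op overrides the check.
bool isFastMathApplicable(const ToyOp &op) {
  return op.alwaysApplicable || op.resultType == TypeKind::Float;
}

// Verifier rule: reject non-none flags on an op where they do not apply.
bool verifyFastMath(const ToyOp &op) {
  if (op.hasFastMathAttr && op.flags != FastMathFlags::none &&
      !isFastMathApplicable(op))
    return false;
  return true;
}
```

In this model a select on floats with contract verifies, the same select on integers does not, and an op that opts in unconditionally (as the patch does for arith.cmpf) verifies regardless of its result type.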

LLVM dialect isFastmathApplicable implementation is based on

return isSupportedFloatingPointType(V->getType());

The Arith dialect's isArithFastMathApplicable is more relaxed, because
it has to support custom MLIR types. This is the area where
improvements are needed (see the TODO comments). I would appreciate
feedback here.
The HLFIR dialect is another example where conditional fast-math
support can be applied today.
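
For the HLFIR case, the check can be pictured as "some result or some operand has a floating-point element type". The sketch below models that idea without the MLIR or flang APIs. `ElemKind`, `ToyType`, `ToyOp` and both functions are hypothetical, and treating only `Real` as compatible is an assumption made to keep the example small.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical toy model; the real code unwraps types via flang helpers.
enum class ElemKind { Integer, Logical, Real };

struct ToyType {
  ElemKind element;  // element type after unwrapping boxes/arrays
  bool isArray;      // kept only to show that array-ness is ignored
};

struct ToyOp {
  std::vector<ToyType> results;
  std::vector<ToyType> operands;
};

// Assumption: only Real elements count as fast-math compatible here.
bool isFloatCompatible(const ToyType &type) {
  return type.element == ElemKind::Real;
}

// Applicable if any result or any operand has a compatible element type;
// otherwise the flags have nothing to act on and the answer is false.
bool isFastMathApplicable(const ToyOp &op) {
  auto anyFloat = [](const std::vector<ToyType> &types) {
    return std::any_of(types.begin(), types.end(), isFloatCompatible);
  };
  return anyFloat(op.results) || anyFloat(op.operands);
}
```

Looking at operands as well as results matters for reductions such as minloc, whose result is an integer index even when the input array is real.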

@llvmbot
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-codegen

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)


Patch is 32.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125620.diff

17 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+26)
  • (modified) flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h (+5)
  • (modified) flang/include/flang/Optimizer/HLFIR/HLFIROps.td (+54)
  • (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+1-3)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+10-2)
  • (modified) flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp (+17)
  • (modified) flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir (+1-1)
  • (modified) flang/test/Fir/tbaa.fir (+3-3)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+14)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td (+49-20)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+50-20)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+11)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithDialect.cpp (+47)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+40)
  • (modified) mlir/test/Dialect/LLVMIR/inlining.mlir (+6-6)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+16-10)
  • (modified) mlir/test/Target/LLVMIR/omptarget-depend.mlir (+3-3)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8dbc9df9f553de..497d099fbe9366 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
                          llvm::cast<mlir::SymbolRefAttr>(callee));
       setOperand(0, llvm::cast<mlir::Value>(callee));
     }
+
+    /// Always allow FastMathFlags for fir.call's.
+    /// It is required to be able to propagate the call site's
+    /// FastMathFlags to the operations resulting from inlining
+    /// (if any) of a fir.call (see SimplifyIntrinsics pass).
+    /// We could analyze the arguments' data types to see if there are
+    /// any floating point types, but this is unreliable. For example,
+    /// the runtime calls mostly take !fir.box<none> arguments,
+    /// and tracking them to the definitions may be not easy.
+    /// TODO: this should be restricted to fir.runtime calls,
+    /// because FastMathFlags for the user calls must come
+    /// from the function body, not the call site.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
   }];
 }
 
@@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
     }
 
     static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);
+
+    /// Always allow FastMathFlags on fir.cmpc.
+    /// It does not produce a floating point result, but
+    /// LLVM is currently relying on fast-math flags attached
+    /// to floating point comparison.
+    /// This can be removed whenever LLVM stops doing it.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
   }];
 }
 
@@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
     static bool isPointerCompatible(mlir::Type ty);
     static bool canBeConverted(mlir::Type inType, mlir::Type outType);
     static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);
+
+    // FIXME: fir.convert should support ArithFastMathInterface.
   }];
   let hasCanonicalizer = 1;
 }
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 15296aa7e8c75c..0e6d536d9bde5d 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
 /// Scalar integer or a sequence of integers (via boxed array or expr).
 bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
 
+/// Return true iff FastMathFlagsAttr is applicable
+/// to the given HLFIR dialect operation that supports
+/// ArithFastMathInterface.
+bool isArithFastMathApplicable(mlir::Operation *op);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index f4102538efc3c2..f90ef8ed019ceb 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
@@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
@@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
@@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
@@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_SetLengthOp : hlfir_Op<"set_length",
@@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_DotProductOp : hlfir_Op<"dot_product",
@@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MatmulOp : hlfir_Op<"matmul",
@@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
   let hasCanonicalizeMethod = 1;
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_TransposeOp : hlfir_Op<"transpose",
@@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_CShiftOp
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index d9779c46ae79e7..d749fc9c633d7c 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
 
 void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
   auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
-  if (fmi) {
-    // TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
-    //       For now set the attribute by the name.
+  if (fmi && fmi.isArithFastMathApplicable()) {
     llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
     if (fastMathFlags != mlir::arith::FastMathFlags::none)
       op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a4959..fca3fb077d0a3f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
     // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
     mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
         attrConvert(call);
-    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
-        call, resultTys, adaptor.getOperands(),
+    auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
+        call.getLoc(), resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
+    auto fmi =
+        mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
+    if (!fmi.isFastmathApplicable())
+      llvmCall->setAttr(
+          mlir::LLVM::CallOp::getFastmathAttrName(),
+          mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
+                                             mlir::LLVM::FastmathFlags::none));
+    rewriter.replaceOp(call, llvmCall);
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index cb77aef74acd56..53637f2090f2ef 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
   mlir::Type elementType = getFortranElementType(unwrappedType);
   return mlir::isa<mlir::IntegerType>(elementType);
 }
+
+bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
+  if (llvm::any_of(op->getResults(), [](mlir::Value v) {
+        mlir::Type elementType = getFortranElementType(v.getType());
+        return mlir::arith::ArithFastMathInterface::isCompatibleType(
+            elementType);
+      }))
+    return true;
+  if (llvm::any_of(op->getOperands(), [](mlir::Value v) {
+        mlir::Type elementType = getFortranElementType(v.getType());
+        return mlir::arith::ArithFastMathInterface::isCompatibleType(
+            elementType);
+      }))
+    return true;
+
+  return false;
+}
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index 0827e378c7c07e..b04188d3ee1d9c 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
     %45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
     gpu.launch_func  @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
     %46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
-    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
     %48 = llvm.mlir.constant(9 : i32) : i32
     %49 = llvm.mlir.zero : !llvm.ptr
     %50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 401ebbc8c49fe6..c2c9ad362370f6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -136,7 +136,7 @@ module {
 // CHECK:           %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
 // CHECK:           %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
 // CHECK:           %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
-// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
 // CHECK:           %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
 // CHECK:           "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
 // CHECK:           %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
@@ -188,8 +188,8 @@ module {
 // CHECK:           %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
-// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
-// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
+// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
+// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
 // CHECK:           llvm.return
 // CHECK:         }
 // CHECK:         llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ea9b0f6509b80b..bd23890556ffdd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
     The destination type must to be strictly wider than the source type.
     When operating on vectors, casts elementwise.
   }];
+  let extraClassDeclaration = [{
+    bool isApplicable() { return true; }
+  }];
   let hasVerifier = 1;
   let hasFolder = 1;
 
@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
   let hasCanonicalizer = 1;
   let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
                           attr-dict `:` type($lhs)}];
+
+  let extraClassDeclaration = [{
+    /// Always allow FastMathFlags on arith.cmpf.
+    /// It does not produce a floating point result, but
+    /// LLVM is currently relying on fast-math flags attached
+    /// to floating point comparison.
+    /// This can be removed whenever LLVM stops doing it.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da..860c096ef2e8b9 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
 
   let cppNamespace = "::mlir::arith";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastMathFlagsAttr attribute for the operation",
-      /*returnType=*/  "FastMathFlagsAttr",
-      /*methodName=*/  "getFastMathFlagsAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
+           /*returnType=*/"FastMathFlagsAttr",
+           /*methodName=*/"getFastMathFlagsAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         ConcreteOp op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastMathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "StringRef",
-      /*methodName=*/  "getFastMathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"StringRef",
+           /*methodName=*/"getFastMathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmath";
-      }]
-      >
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastMathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         ArithFastMathInterface. If it returns false,
+                         then the FastMathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isArithFastMathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
 
-  ];
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point type
+    /// or contains one.
+    static bool isCompatibleType(::mlir::Type);
+
+    /// Default implementation of isArithFastMathApplicable().
+    /// It returns true iff any of the results of the operations
+    /// has a type that is compatible with fast-math.
+    bool isApplicableImpl();
+  }];
+
+  let verify = [{
+    auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
+    auto attr = fmi.getFastMathFlagsAttr();
+    if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
+        !fmi.isArithFastMathApplicable())
+      return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+    return ::mlir::success();
+  }];
 }
 
 def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 5ccddef158d9c2..ca55f933e4efad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -22,30 +22,60 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
 
   let cppNamespace = "::mlir::LLVM";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastmathFlagsAttr attribute for the operation",
-      /*returnType=*/  "::mlir::LLVM::FastmathFlagsAttr",
-      /*methodName=*/  "getFastmathAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastmathFlagsAttr attribute for the operation",
+           /*returnType=*/"::mlir::LLVM::FastmathFlagsAttr",
+           /*methodName=*/"getFastmathAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         auto op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathFlagsAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastmathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastmathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "::llvm::StringRef",
-      /*methodName=*/  "getFastmathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"::llvm::StringRef",
+           /*methodName=*/"getFastmathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmathFlags";
-      }]
-      >
-  ];
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastmathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         FastmathInterface. If it returns false,
+                         then the FastmathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isFastmathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
+
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point typ...
[truncated]

@llvmbot
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-mlir-arith

Author: Slava Zakharin (vzakhari)

index cb4eb8303a4959..fca3fb077d0a3f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
     // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
     mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
         attrConvert(call);
-    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
-        call, resultTys, adaptor.getOperands(),
+    auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
+        call.getLoc(), resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
+    auto fmi =
+        mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
+    if (!fmi.isFastmathApplicable())
+      llvmCall->setAttr(
+          mlir::LLVM::CallOp::getFastmathAttrName(),
+          mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
+                                             mlir::LLVM::FastmathFlags::none));
+    rewriter.replaceOp(call, llvmCall);
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index cb77aef74acd56..53637f2090f2ef 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
   mlir::Type elementType = getFortranElementType(unwrappedType);
   return mlir::isa<mlir::IntegerType>(elementType);
 }
+
+bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
+  if (llvm::any_of(op->getResults(), [](mlir::Value v) {
+        mlir::Type elementType = getFortranElementType(v.getType());
+        return mlir::arith::ArithFastMathInterface::isCompatibleType(
+            elementType);
+      }))
+    return true;
+  if (llvm::any_of(op->getOperands(), [](mlir::Value v) {
+        mlir::Type elementType = getFortranElementType(v.getType());
+        return mlir::arith::ArithFastMathInterface::isCompatibleType(
+            elementType);
+      }))
+    return true;
+
+  return true;
+}
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index 0827e378c7c07e..b04188d3ee1d9c 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
     %45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
     gpu.launch_func  @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
     %46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
-    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
     %48 = llvm.mlir.constant(9 : i32) : i32
     %49 = llvm.mlir.zero : !llvm.ptr
     %50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 401ebbc8c49fe6..c2c9ad362370f6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -136,7 +136,7 @@ module {
 // CHECK:           %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
 // CHECK:           %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
 // CHECK:           %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
-// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
 // CHECK:           %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
 // CHECK:           "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
 // CHECK:           %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
@@ -188,8 +188,8 @@ module {
 // CHECK:           %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
-// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
-// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
+// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
+// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
 // CHECK:           llvm.return
 // CHECK:         }
 // CHECK:         llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ea9b0f6509b80b..bd23890556ffdd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
     The destination type must to be strictly wider than the source type.
     When operating on vectors, casts elementwise.
   }];
+  let extraClassDeclaration = [{
+    bool isApplicable() { return true; }
+  }];
   let hasVerifier = 1;
   let hasFolder = 1;
 
@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
   let hasCanonicalizer = 1;
   let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
                           attr-dict `:` type($lhs)}];
+
+  let extraClassDeclaration = [{
+    /// Always allow FastMathFlags on arith.cmpf.
+    /// It does not produce a floating point result, but
+    /// LLVM is currently relying on fast-math flags attached
+    /// to floating point comparison.
+    /// This can be removed whenever LLVM stops doing it.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da..860c096ef2e8b9 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
 
   let cppNamespace = "::mlir::arith";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastMathFlagsAttr attribute for the operation",
-      /*returnType=*/  "FastMathFlagsAttr",
-      /*methodName=*/  "getFastMathFlagsAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
+           /*returnType=*/"FastMathFlagsAttr",
+           /*methodName=*/"getFastMathFlagsAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         ConcreteOp op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastMathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "StringRef",
-      /*methodName=*/  "getFastMathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"StringRef",
+           /*methodName=*/"getFastMathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmath";
-      }]
-      >
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastMathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         ArithFastMathInterface. If it returns false,
+                         then the FastMathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isArithFastMathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
 
-  ];
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point type
+    /// or contains one.
+    static bool isCompatibleType(::mlir::Type);
+
+    /// Default implementation of isArithFastMathApplicable().
+    /// It returns true iff any of the results of the operations
+    /// has a type that is compatible with fast-math.
+    bool isApplicableImpl();
+  }];
+
+  let verify = [{
+    auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
+    auto attr = fmi.getFastMathFlagsAttr();
+    if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
+        !fmi.isArithFastMathApplicable())
+      return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+    return ::mlir::success();
+  }];
 }
 
 def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 5ccddef158d9c2..ca55f933e4efad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -22,30 +22,60 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
 
   let cppNamespace = "::mlir::LLVM";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastmathFlagsAttr attribute for the operation",
-      /*returnType=*/  "::mlir::LLVM::FastmathFlagsAttr",
-      /*methodName=*/  "getFastmathAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastmathFlagsAttr attribute for the operation",
+           /*returnType=*/"::mlir::LLVM::FastmathFlagsAttr",
+           /*methodName=*/"getFastmathAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         auto op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathFlagsAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastmathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastmathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "::llvm::StringRef",
-      /*methodName=*/  "getFastmathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"::llvm::StringRef",
+           /*methodName=*/"getFastmathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmathFlags";
-      }]
-      >
-  ];
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastmathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         FastmathInterface. If it returns false,
+                         then the FastmathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isFastmathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
+
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point typ...
[truncated]
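
The `CallOpConversion` hunk above first creates the `llvm.call` and then resets its flags to `none` when they do not apply, which is why the updated tests no longer expect `fastmathFlags` on the I/O runtime calls. Below is a minimal standalone sketch of that pattern. It does not use MLIR; `ToyCall`, `TypeKind`, `FastMath`, and `lowerCall` are hypothetical names, and the applicability rule (some result is floating point) is an assumption, since the LLVM dialect implementation is truncated in the diff.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins; none of these names come from the patch.
enum class TypeKind { Integer, Float, Pointer };
enum class FastMath { none, contract };

struct ToyCall {
  std::vector<TypeKind> operandTypes;
  std::vector<TypeKind> resultTypes;
  FastMath flags = FastMath::none;
};

// Assumed rule for this sketch: flags apply iff some result is floating point.
static bool isFastmathApplicable(const ToyCall &call) {
  for (TypeKind t : call.resultTypes)
    if (t == TypeKind::Float)
      return true;
  return false;
}

// Copy the source call, flags included, then drop the flags if the
// lowered call cannot legally carry them.
static ToyCall lowerCall(const ToyCall &src) {
  ToyCall dst = src;
  if (!isFastmathApplicable(dst))
    dst.flags = FastMath::none;
  return dst;
}
```

Creating the op before querying it mirrors the patch's switch from `replaceOpWithNewOp` to `create` followed by `replaceOp`: the interface method can only be called on an operation that already exists.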

@llvmbot
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-mlir

Author: Slava Zakharin (vzakhari)

Changes

This patch proposes changes for operations that support
arith::ArithFastMathInterface/LLVM::FastmathFlagsInterface.
Some of these operations may carry fast-math flags other than none
only if they operate on floating-point values.

This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types
and by my goal of adding fast-math support to the arith.select operation,
which may produce results of any type.

The changes add new isArithFastMathApplicable/isFastmathApplicable
methods to the above interfaces that tell whether an operation
supporting the interface may have non-none fast-math flags.

The LLVM dialect's isFastmathApplicable implementation is based on

return isSupportedFloatingPointType(V->getType());

The Arith dialect's isArithFastMathApplicable is more relaxed because
it has to support custom MLIR types. This is the area where
improvements are needed (see the TODO comments); I would appreciate
feedback here.
The HLFIR dialect is another example where conditional fast-math
support can be applied today.
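
To make the proposal concrete, here is a minimal standalone sketch of the two pieces described above: an applicability predicate and a verifier that rejects non-`none` flags on operations where they do not apply. It deliberately avoids MLIR; `ToyOp`, `TypeKind`, `FastMath`, and `verify` are hypothetical names, and the default rule (some result is floating point) follows the doc comment on `isApplicableImpl` rather than its actual implementation, which is not shown in the truncated diff.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for MLIR types and fast-math flags; none of these
// names come from the patch.
enum class TypeKind { Integer, Float, Pointer };
enum class FastMath { none, contract, nnan };

struct ToyOp {
  std::vector<TypeKind> resultTypes;
  FastMath flags = FastMath::none;
  // Models per-op overrides such as arith.cmpf, which always allows flags.
  bool alwaysApplicable = false;
};

// Default rule (assumed): flags apply iff some result is floating point.
static bool isFastMathApplicable(const ToyOp &op) {
  if (op.alwaysApplicable)
    return true;
  for (TypeKind t : op.resultTypes)
    if (t == TypeKind::Float)
      return true;
  return false;
}

// Mirrors the interface verifier in spirit: flags other than none are
// only accepted on operations for which they are applicable.
static bool verify(const ToyOp &op) {
  return op.flags == FastMath::none || isFastMathApplicable(op);
}
```

Under these assumptions a select-like op with an integer result fails verification when it carries `contract`, passes with `none`, and a comparison-like op passes through its override even though its result is not floating point.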


Patch is 32.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125620.diff

17 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+26)
  • (modified) flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h (+5)
  • (modified) flang/include/flang/Optimizer/HLFIR/HLFIROps.td (+54)
  • (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+1-3)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+10-2)
  • (modified) flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp (+17)
  • (modified) flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir (+1-1)
  • (modified) flang/test/Fir/tbaa.fir (+3-3)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+14)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td (+49-20)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+50-20)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+11)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithDialect.cpp (+47)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+40)
  • (modified) mlir/test/Dialect/LLVMIR/inlining.mlir (+6-6)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+16-10)
  • (modified) mlir/test/Target/LLVMIR/omptarget-depend.mlir (+3-3)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8dbc9df9f553de..497d099fbe9366 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
                          llvm::cast<mlir::SymbolRefAttr>(callee));
       setOperand(0, llvm::cast<mlir::Value>(callee));
     }
+
+    /// Always allow FastMathFlags for fir.call's.
+    /// It is required to be able to propagate the call site's
+    /// FastMathFlags to the operations resulting from inlining
+    /// (if any) of a fir.call (see SimplifyIntrinsics pass).
+    /// We could analyze the arguments' data types to see if there are
+    /// any floating point types, but this is unreliable. For example,
+    /// the runtime calls mostly take !fir.box<none> arguments,
+    /// and tracking them to the definitions may be not easy.
+    /// TODO: this should be restricted to fir.runtime calls,
+    /// because FastMathFlags for the user calls must come
+    /// from the function body, not the call site.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
   }];
 }
 
@@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
     }
 
     static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);
+
+    /// Always allow FastMathFlags on fir.cmpc.
+    /// It does not produce a floating point result, but
+    /// LLVM is currently relying on fast-math flags attached
+    /// to floating point comparison.
+    /// This can be removed whenever LLVM stops doing it.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
   }];
 }
 
@@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
     static bool isPointerCompatible(mlir::Type ty);
     static bool canBeConverted(mlir::Type inType, mlir::Type outType);
     static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);
+
+    // FIXME: fir.convert should support ArithFastMathInterface.
   }];
   let hasCanonicalizer = 1;
 }
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 15296aa7e8c75c..0e6d536d9bde5d 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
 /// Scalar integer or a sequence of integers (via boxed array or expr).
 bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
 
+/// Return true iff FastMathFlagsAttr is applicable
+/// to the given HLFIR dialect operation that supports
+/// ArithFastMathInterface.
+bool isArithFastMathApplicable(mlir::Operation *op);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index f4102538efc3c2..f90ef8ed019ceb 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
@@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
@@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
@@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
@@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_SetLengthOp : hlfir_Op<"set_length",
@@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_DotProductOp : hlfir_Op<"dot_product",
@@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MatmulOp : hlfir_Op<"matmul",
@@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
   let hasCanonicalizeMethod = 1;
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_TransposeOp : hlfir_Op<"transpose",
@@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_CShiftOp
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index d9779c46ae79e7..d749fc9c633d7c 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
 
 void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
   auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
-  if (fmi) {
-    // TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
-    //       For now set the attribute by the name.
+  if (fmi && fmi.isArithFastMathApplicable()) {
     llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
     if (fastMathFlags != mlir::arith::FastMathFlags::none)
       op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a4959..fca3fb077d0a3f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
     // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
     mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
         attrConvert(call);
-    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
-        call, resultTys, adaptor.getOperands(),
+    auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
+        call.getLoc(), resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
+    auto fmi =
+        mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
+    if (!fmi.isFastmathApplicable())
+      llvmCall->setAttr(
+          mlir::LLVM::CallOp::getFastmathAttrName(),
+          mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
+                                             mlir::LLVM::FastmathFlags::none));
+    rewriter.replaceOp(call, llvmCall);
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index cb77aef74acd56..53637f2090f2ef 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
   mlir::Type elementType = getFortranElementType(unwrappedType);
   return mlir::isa<mlir::IntegerType>(elementType);
 }
+
+bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
+  if (llvm::any_of(op->getResults(), [](mlir::Value v) {
+        mlir::Type elementType = getFortranElementType(v.getType());
+        return mlir::arith::ArithFastMathInterface::isCompatibleType(
+            elementType);
+      }))
+    return true;
+  if (llvm::any_of(op->getOperands(), [](mlir::Value v) {
+        mlir::Type elementType = getFortranElementType(v.getType());
+        return mlir::arith::ArithFastMathInterface::isCompatibleType(
+            elementType);
+      }))
+    return true;
+
+  return true;
+}
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index 0827e378c7c07e..b04188d3ee1d9c 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
     %45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
     gpu.launch_func  @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
     %46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
-    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
     %48 = llvm.mlir.constant(9 : i32) : i32
     %49 = llvm.mlir.zero : !llvm.ptr
     %50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 401ebbc8c49fe6..c2c9ad362370f6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -136,7 +136,7 @@ module {
 // CHECK:           %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
 // CHECK:           %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
 // CHECK:           %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
-// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
 // CHECK:           %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
 // CHECK:           "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
 // CHECK:           %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
@@ -188,8 +188,8 @@ module {
 // CHECK:           %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
-// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
-// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
+// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
+// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
 // CHECK:           llvm.return
 // CHECK:         }
 // CHECK:         llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ea9b0f6509b80b..bd23890556ffdd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
     The destination type must to be strictly wider than the source type.
     When operating on vectors, casts elementwise.
   }];
+  let extraClassDeclaration = [{
+    bool isApplicable() { return true; }
+  }];
   let hasVerifier = 1;
   let hasFolder = 1;
 
@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
   let hasCanonicalizer = 1;
   let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
                           attr-dict `:` type($lhs)}];
+
+  let extraClassDeclaration = [{
+    /// Always allow FastMathFlags on arith.cmpf.
+    /// It does not produce a floating point result, but
+    /// LLVM is currently relying on fast-math flags attached
+    /// to floating point comparison.
+    /// This can be removed whenever LLVM stops doing it.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da..860c096ef2e8b9 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
 
   let cppNamespace = "::mlir::arith";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastMathFlagsAttr attribute for the operation",
-      /*returnType=*/  "FastMathFlagsAttr",
-      /*methodName=*/  "getFastMathFlagsAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
+           /*returnType=*/"FastMathFlagsAttr",
+           /*methodName=*/"getFastMathFlagsAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         ConcreteOp op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastMathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "StringRef",
-      /*methodName=*/  "getFastMathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"StringRef",
+           /*methodName=*/"getFastMathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmath";
-      }]
-      >
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastMathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         ArithFastMathInterface. If it returns false,
+                         then the FastMathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isArithFastMathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
 
-  ];
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point type
+    /// or contains one.
+    static bool isCompatibleType(::mlir::Type);
+
+    /// Default implementation of isArithFastMathApplicable().
+    /// It returns true iff any of the results of the operations
+    /// has a type that is compatible with fast-math.
+    bool isApplicableImpl();
+  }];
+
+  let verify = [{
+    auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
+    auto attr = fmi.getFastMathFlagsAttr();
+    if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
+        !fmi.isArithFastMathApplicable())
+      return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+    return ::mlir::success();
+  }];
 }
 
 def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 5ccddef158d9c2..ca55f933e4efad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -22,30 +22,60 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
 
   let cppNamespace = "::mlir::LLVM";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastmathFlagsAttr attribute for the operation",
-      /*returnType=*/  "::mlir::LLVM::FastmathFlagsAttr",
-      /*methodName=*/  "getFastmathAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastmathFlagsAttr attribute for the operation",
+           /*returnType=*/"::mlir::LLVM::FastmathFlagsAttr",
+           /*methodName=*/"getFastmathAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         auto op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathFlagsAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastmathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastmathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "::llvm::StringRef",
-      /*methodName=*/  "getFastmathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"::llvm::StringRef",
+           /*methodName=*/"getFastmathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmathFlags";
-      }]
-      >
-  ];
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastmathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         FastmathInterface. If it returns false,
+                         then the FastmathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isFastmathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
+
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point typ...
[truncated]

@kuhar kuhar requested review from kuhar, dcaballe and chelini February 4, 2025 02:27
Member

@kuhar kuhar left a comment

> and my goal to add fast-math support for arith.select operation

Why would we want to have fast math flags over arith.select? What optimizations / rewrites does this allow?

@vzakhari
Contributor Author

vzakhari commented Feb 4, 2025

> > and my goal to add fast-math support for arith.select operation
>
> Why would we want to have fast math flags over arith.select? What optimizations / rewrites does this allow?

For example, it enables vectorization of loops with min/max reductions in LLVM. Flang is currently producing arith.select without fast-math attrs.

In general, in LLVM any instruction that produces a floating point result may have fast-math flags. This includes FP PHIs and selects.
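
For illustration only (this is not taken from the patch or from Flang), a max reduction of the kind mentioned above has the following shape in C++: each iteration is a floating point comparison followed by a select between two floating point values. The name `maxReduce` is made up for this sketch, and whether a given compiler vectorizes such a loop depends on its own rules and options.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical example: a max reduction written as compare + select.
float maxReduce(const std::vector<float> &data, float init) {
  float current = init;
  for (std::size_t i = 0; i < data.size(); ++i) {
    // The comparison yields a boolean; the conditional expression then picks
    // one of the two floating point values without modifying either.
    bool greater = data[i] > current;
    current = greater ? data[i] : current;
  }
  return current;
}
```

The select is the operation that produces the floating point value carried around the loop, which is why flags attached only to the comparison do not describe it.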

}))
return true;

return true;
Contributor Author

I will change this to false after fixing the lowering tests.

auto attr = fmi.getFastMathFlagsAttr();
if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
!fmi.isArithFastMathApplicable())
return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
Collaborator

Suggested change
- return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+ return $_op->emitOpError() << "has flag " << stringify(attr.getValue()) << " but fast-math flags are not applicable (`isArithFastMathApplicable()` returns false)";

Contributor Author

Thanks! Will apply.

@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
The destination type must to be strictly wider than the source type.
When operating on vectors, casts elementwise.
}];
let extraClassDeclaration = [{
bool isApplicable() { return true; }
Collaborator

When is this method called?

Did you mean isArithFastMathApplicable() here? (if so we're missing a test to cover this)

Contributor Author

This is a leftover. There is no need to override isArithFastMathApplicable for arith.extf, because it has a floating point result.

auto fmi =
mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
if (!fmi.isFastmathApplicable())
llvmCall->setAttr(
Collaborator

There should be a better accessor on LLVM::CallOp (this one is generic and quite expensive): ODS generates an accessor per attribute.

Contributor Author

Right. Will fix.

then the FastMathFlagsAttr of the operation
must be nullptr or have 'none' value}],
/*returnType=*/"bool",
/*methodName=*/"isArithFastMathApplicable",
Contributor

Should I think of this as a sort of "verifier" for fast-math flags on the given operation?

Contributor Author

Its intention is to tell whether fast-math flags are applicable. It is used in the verifier code below, but it may also be used by the passes/builders that create new operations supporting ArithFastMathInterface, e.g. see its usage in the FIRBuilder.cpp file above.
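
The two uses described above can be sketched in plain C++ without any MLIR dependency. This is only a model under stated assumptions: `FastMathFlags`, `ModelOp`, `verifyFastMath` and `buildWithFlags` are hypothetical stand-ins, not the real `arith` API, and the type check is reduced to a boolean.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-in for a fast-math flags enum; only the shape matters.
enum class FastMathFlags : std::uint32_t {
  none = 0,
  nnan = 1u << 0,
  ninf = 1u << 1,
  contract = 1u << 2,
};

// Hypothetical model of an operation that implements the interface.
struct ModelOp {
  std::optional<FastMathFlags> fastmath; // std::nullopt models a missing attribute
  bool hasFloatResult;                   // stands in for the result type check

  // Plays the role of isArithFastMathApplicable() with the default rule.
  bool isFastMathApplicable() const { return hasFloatResult; }
};

// Same condition as the verifier in the patch: a value other than 'none' is
// rejected when the operation reports that flags are not applicable.
bool verifyFastMath(const ModelOp &op) {
  if (op.fastmath && *op.fastmath != FastMathFlags::none &&
      !op.isFastMathApplicable())
    return false;
  return true;
}

// A builder can ask the same question up front and skip the attribute
// instead of creating an operation that would fail verification.
ModelOp buildWithFlags(bool hasFloatResult, FastMathFlags requested) {
  ModelOp op{std::nullopt, hasFloatResult};
  if (op.isFastMathApplicable())
    op.fastmath = requested;
  return op;
}
```

In this model the query is a predicate shared by the verifier and by builders, which is the distinction from a verifier that can only reject after the fact.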

let extraClassDeclaration = [{
/// Always allow FastmathFlags on llvm.fcmp.
/// It does not produce a floating point result, but
/// LLVM is currently relying on fast-math flags attached
Contributor

Do you mean that LLVM will look into the compare instead of the select operation for fast-math flags? It is unclear to me why you are changing this method for the fcmp and not the select.

Contributor Author

LLVM can look at both compare and select, depending on what it needs to do.

What I mean here is: LLVM's fcmp instruction supports fast-math flags, and the llvm.fcmp operation should also support them. The general rule for instructions/operations to support fast-math is that they produce a floating point result. Neither fcmp nor llvm.fcmp produces a floating point result, so they are exceptions to the general rule, and isFastmathApplicable should be overridden here.

There is no need to override isFastmathApplicable for llvm.select, because it is covered by the general rule.

Note that the comment is explicitly saying that this is a temporary solution while LLVM expects it.

LLVM code has the following TODO about fcmp:

    // FIXME: To clean up and correct the semantics of fast-math-flags, FCmp
    //        should not be treated as a math op, but the other opcodes should.
    //        This would make things consistent with Select/PHI (FP value type
    //        determines whether they are math ops and, therefore, capable of
    //        having fast-math-flags).
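
The general rule and the fcmp exception described in this comment can be modeled in a few lines of standalone C++. The class names below (`ModelOp`, `ModelSelect`, `ModelFCmp`) are hypothetical and exist only for this sketch; they are not the LLVM dialect classes.

```cpp
// Hypothetical model of an operation that supports the fast-math interface.
struct ModelOp {
  virtual ~ModelOp() = default;
  virtual bool producesFloatResult() const = 0;
  // General rule: flags are applicable iff the result is floating point.
  virtual bool isFastmathApplicable() const { return producesFloatResult(); }
};

// A select follows the general rule, so it needs no override.
struct ModelSelect : ModelOp {
  bool floatOperands;
  explicit ModelSelect(bool isFloat) : floatOperands(isFloat) {}
  bool producesFloatResult() const override { return floatOperands; }
};

// A floating point compare yields a boolean, so the general rule would say
// "not applicable". The override keeps flags allowed as an explicit exception.
struct ModelFCmp : ModelOp {
  bool producesFloatResult() const override { return false; }
  bool isFastmathApplicable() const override { return true; }
};
```

If the exception were ever dropped, only the override in the compare model would need to be removed.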

@kuhar
Member

kuhar commented Feb 4, 2025

> For example, it enables vectorization of loops with min/max reductions in LLVM. Flang is currently producing arith.select without fast-math attrs.
>
> In general, in LLVM any instruction that produces a floating point result may have fast-math flags. This includes FP PHIs and selects.

Can you link to some example that shows why this is necessary? I'd think that arith.select preserves bitpatterns of its operands, so I struggle to see why it needs fast math flags; if llvm needs them, couldn't it calculate them as the intersection of the fast math flags of the operands?
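
As a sketch of what "the intersection of the fast math flags of the operands" could mean when flags are stored as a bitmask (the names `Flags`, `kNnan`, `kNinf`, `kNsz` and `intersectFlags` are assumptions made for this example, not an existing API):

```cpp
#include <cstdint>
#include <initializer_list>

using Flags = std::uint32_t;
constexpr Flags kNone = 0;
constexpr Flags kNnan = 1u << 0;
constexpr Flags kNinf = 1u << 1;
constexpr Flags kNsz = 1u << 2;

// Keeps only the flags that every operand carries. An empty operand list
// yields no flags at all.
constexpr Flags intersectFlags(std::initializer_list<Flags> operandFlags) {
  Flags result = ~Flags{0};
  for (Flags flags : operandFlags)
    result &= flags;
  return operandFlags.size() == 0 ? kNone : result;
}
```

A single operand without flags clears the result, so this computation is conservative by construction.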

@dcaballe
Contributor

dcaballe commented Feb 4, 2025

> > For example, it enables vectorization of loops with min/max reductions in LLVM. Flang is currently producing arith.select without fast-math attrs.
> >
> > In general, in LLVM any instruction that produces a floating point result may have fast-math flags. This includes FP PHIs and selects.
>
> Can you link to some example that shows why this is necessary? I'd think that arith.select preserves bitpatterns of its operands, so I struggle to see why it needs fast math flags; if llvm needs them, couldn't it calculate them as the intersection of the fast math flags of the operands?

I was asking myself the same question. I guess if an arith.select has nnan and one of the inputs is NaN we could turn it into poison...

@kuhar
Member

kuhar commented Feb 4, 2025

> I was asking myself the same question. I guess if an arith.select has nnan and one of the inputs is NaN we could turn it into poison...

This could be supported by a dedicated unary op.

@vzakhari
Contributor Author

vzakhari commented Feb 4, 2025

> Can you link to some example that shows why this is necessary? I'd think that arith.select preserves bitpatterns of its operands, so I struggle to see why it needs fast math flags; if llvm needs them, couldn't it calculate them as the intersection of the fast math flags of the operands?

I am sorry, I do not have examples readily available for you.

There is an LLVM Floating Point Working Group (https://discourse.llvm.org/t/floating-point-working-group/76907/10) that discussed the need for fast-math flags on select instructions here: #51601 (see also the notes from February 21, 2024 in https://docs.google.com/document/d/1QcmUlWftPlBi-Wz6b6PipqJfvjpJ-OuRMRnN9Dm2t0c/edit?tab=t.0#heading=h.k3uvggph248w).

The initial addition of fast-math support for select was done in 2019: 5a4f7cf

@kuhar
Member

kuhar commented Feb 4, 2025

I think we could do something like this:

  • add a new op, say %y = arith.assumef %x fastmath<nnan> : f32 (or ub.assumef?) whose sole purpose is to apply fast math flags to its operand
  • teach arith to llvm conversion to convert arith.select(arith.assumef(x), arith.assumef(y)) to apply fast math flags to the produced llvm.select (and potentially other ops)
  • teach the frontend to produce these arith.assumef ops when emitting selects with fast math enabled

@vzakhari
Contributor Author

vzakhari commented Feb 4, 2025

> I think we could do something like this:
>
>   • add a new op, say %y = arith.assumef %x fastmath<nnan> : f32 (or ub.assumef?) whose sole purpose is to apply fast math flags to its operand
>   • teach arith to llvm conversion to convert arith.select(arith.assumef(x), arith.assumef(y)) to apply fast math flags to the produced llvm.select (and potentially other ops)
>   • teach the frontend to produce these arith.assumef ops when emitting selects with fast math enabled

Can you please explain how this is better than having fast-math flags on the select itself? It seems that what you are proposing depends on whether the operands' arith.assumef definitions are always reachable from arith.select, but this may not be true with block arguments. So having arith.select carry the fast-math flags seems more robust to me.

Please also note that this patch does not add fast-math support to arith.select; it just adds conditional fast-math support in general and for operations of the HLFIR dialect. I appreciate the discussion about arith.select, but I just want to understand if you are also proposing some changes to the current patch. Can you please clarify?

@vzakhari
Contributor Author

vzakhari commented Feb 4, 2025

FYI, I addressed the review comments in 8834e36, but GitHub does not show it here for some reason. I see a "Processing updates" spinner idling at the top of this page...

@kuhar
Member

kuhar commented Feb 4, 2025

> Please also note that this patch does not add fast-math support to arith.select; it just adds conditional fast-math support in general and for operations of the HLFIR dialect. I appreciate the discussion about arith.select, but I just want to understand if you are also proposing some changes to the current patch. Can you please clarify?

Sorry, I was mostly thinking out loud and did not mean to derail this PR. I'm trying to understand the goal stated in the PR description:

> This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types and my goal to add fast-math support for arith.select operation that may produce results of any type.

I'd not expect arith.select to know anything about fast math because of the reasons mentioned above in #125620 (comment). Separately, fast math in clang/llvm is known to be broken and I'd like to make sure we are not working backwards from the solution and inadvertently repeating the same mistakes. For example, see this issue: iree-org/iree#19743. I will set aside some time to read the llvm discussion that you linked and understand their design better.

> It seems that what you are proposing depends on whether the operands' arith.assumef definitions are always reachable from arith.select, but this may not be true with block arguments. So having arith.select carry the fast-math flags seems more robust to me.

I'd think that the real implementation can have a matcher that checks if arith.select operands are <insert-your-favorite-fmf-assumption>, and frontends can insert these assumef ops to communicate this in the absence of anything else that would imply it.

@vzakhari
Contributor Author

vzakhari commented Feb 4, 2025

> Sorry, I was mostly thinking out loud and did not mean to derail this PR. I'm trying to understand the goal stated in the PR description:
>
> > This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types and my goal to add fast-math support for arith.select operation that may produce results of any type.

Thanks for the explanation!

> I'd think that the real implementation can have a matcher that checks if arith.select operands are <insert-your-favorite-fmf-assumption>, and frontends can insert these assumef ops to communicate this in the absence of anything else that would imply it.

I am not against the assume operations, but what I am trying to say is that MLIR region simplification may turn the operands of arith.select into block arguments, so the matching will not be as easy as looking at the operands. It will be more like looking at all the possible values of the block arguments and merging the fast-math flags across all these values. At the same time, some attribution of the block arguments with fast-math flags may be needed anyway, if we decide to follow LLVM's approach and mark the PHIs with fast-math flags.
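
To make the merging across block arguments concrete, here is a small standalone C++ model. It is a sketch under strong assumptions: `ModelValue`, `flagsOnAllPaths` and `exampleGraph` are invented for this illustration, values are referred to by index, and the value graph is assumed to be acyclic (a loop-carried block argument would need a fixed-point computation instead of plain recursion).

```cpp
#include <cstdint>
#include <vector>

using Flags = std::uint32_t;
constexpr Flags kNone = 0;
constexpr Flags kNnan = 1u << 0;
constexpr Flags kNinf = 1u << 1;

// Hypothetical value model: either a definition with known flags, or a block
// argument whose possible values are the listed incoming values.
struct ModelValue {
  Flags known = kNone;       // used only when 'incoming' is empty
  std::vector<int> incoming; // indices of incoming values for a block argument
};

// Flags that hold for a value on every path that can reach it.
// Assumes the graph formed by 'incoming' has no cycles.
Flags flagsOnAllPaths(const std::vector<ModelValue> &values, int index) {
  const ModelValue &value = values[index];
  if (value.incoming.empty())
    return value.known;
  Flags result = ~Flags{0};
  for (int pred : value.incoming)
    result &= flagsOnAllPaths(values, pred);
  return result;
}

// Example: value 2 is a block argument fed by values 0 and 1.
std::vector<ModelValue> exampleGraph() {
  return {
      {kNnan | kNinf, {}}, // 0: defined with nnan and ninf
      {kNnan, {}},         // 1: defined with nnan only
      {kNone, {0, 1}},     // 2: block argument merging values 0 and 1
  };
}
```

In this model the answer for value 2 is only `nnan`, and obtaining it requires visiting every incoming value rather than a single defining operation.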

@vzakhari vzakhari force-pushed the conditional_fast_math branch from 8834e36 to a517b83 Compare February 4, 2025 23:10
return true;

// TODO: what about TupleType and custom dialect struct-like types?
// It seems that they worth an interface to get to the list of element types.
Contributor Author

Any suggestions about this?

Contributor

Do you have a use case of an operation with the fastmath interface that takes/returns such a type?

If these operations are not common, maybe it is best/cheaper for the rest of the usages to keep the logic here simple and have these operations do the type visit as needed, like you did in HLFIR.

That said, such a type interface would make some sense to me in general.

Contributor Author

I was mostly thinking about arith.select, but it seems it will be a separate discussion and changes (if any).

> If these operations are not common, maybe it is best/cheaper for the rest of the usages to keep the logic here simple and have these operations do the type visit as needed, like you did in HLFIR.

This might be an acceptable approach. I was thinking about making the other dialects' life easier by handling it here, but I can postpone the decision until the need arises.
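
For reference, the kind of type visit discussed in this thread can be sketched as a recursive check over a toy type tree. Everything here is hypothetical (`ModelType`, its `Kind` values and `containsFloat` are not MLIR classes); the sketch only shows what handling tuple-like or struct-like types would add on top of the scalar and shaped cases.

```cpp
#include <vector>

// Hypothetical type model: scalars have no elements, a shaped type has one
// element type, and a tuple-like type has any number of element types.
struct ModelType {
  enum Kind { Float, Integer, Shaped, Tuple } kind;
  std::vector<ModelType> elements;
};

// Returns true iff the type is a floating point type or contains one.
bool containsFloat(const ModelType &type) {
  switch (type.kind) {
  case ModelType::Float:
    return true;
  case ModelType::Integer:
    return false;
  case ModelType::Shaped:
  case ModelType::Tuple:
    // Visit nested element types; one floating point element is enough.
    for (const ModelType &element : type.elements)
      if (containsFloat(element))
        return true;
    return false;
  }
  return false;
}
```

The cost of the visit grows with the nesting depth of the type, which is the trade-off behind leaving struct-like types to the dialects that actually need them.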

/// TODO: the results often have the same type, and traversing
/// the same type again and again is not very efficient.
/// We can cache it here for the duration of the processing.
/// Other ideas?
Contributor Author

Any suggestions about this?

Contributor

Are you talking about the results of the same operation with multiple results, or the results of different operations?

If this is about the former, it seems to me that caching would be overkill given the average number of results in an operation. If this is about the latter, it seems to me that maintaining some shared cache somewhere would not be cheap. Is the call to isCompatibleType that expensive?

Contributor Author

Yes, I was thinking about the results of the same operation, especially for the case of struct-like types that may have nested types... I agree with you that in the current state of isCompatibleType it does not make sense to do any caching. I will remove the comment.

Contributor

@jeanPerier jeanPerier left a comment

Thanks @vzakhari, the flang dialect changes look good and this makes sense to me in general. Please wait for approval from those who had comments.

return isCompatibleType(shapedType.getElementType());

// ComplexType's element type is always a FloatType.
if (auto complexType = dyn_cast<ComplexType>(type))
Contributor

nit: move this into the isa check with FloatType?

@benvanik
Contributor

benvanik commented Feb 7, 2025

This feels very odd to me - like an implementation detail of LLVM leaking way higher up into the stack than it should. arith.select is a conditional move op and should never change the value of either operand. Ops like arith.select and block arguments are changed repeatedly throughout many pipelines and preserving (or even knowing how to set) fast math information is not feasible (or useful). If lowering into LLVM needs this information then it should perform analysis to determine the required information - it is much easier to do so at the time of lowering than it is to keep that information present and valid through layers as high up as tensor graph programming all the way down to LLVM that may be passing through half a dozen intermediate dialects.

@vzakhari
Contributor Author

vzakhari commented Feb 7, 2025

> arith.select is a conditional move op and should never change the value of either operand.

I am not sure I understand how fast-math attributes attached to arith.select imply that it can change the value of either operand - can you please clarify?

> Ops like arith.select and block arguments are changed repeatedly throughout many pipelines and preserving (or even knowing how to set) fast math information is not feasible (or useful).

Isn't it true for any other discardable attribute? In my opinion, the transformations should drop any attributes that they don't know about or preserve any attribute that they know about (generic attribute interfaces may ease the handling here). Otherwise, how can we guarantee that a discardable attribute propagated during an op->op transformation, without paying attention to preserving the attribute semantics, is still valid in the MLIR after the transformation?

I created https://discourse.llvm.org/t/rfc-arithfastmathinterface-support-for-arith-select/84508 for further discussion of the fast-math flags on arith.select. Can we please continue there?

I think the discussion here should be about conditional support of fast-math attributes on operations that support the fast-math interface. I believe the HLFIR case is a good example where such support makes sense.

> If lowering into LLVM needs this information then it should perform analysis to determine the required information

I do not think it is always possible (see example in the above discourse RFC).

Contributor Author

@vzakhari vzakhari left a comment

Thank you for the comments, Jean. I will upload updated files shortly.

8 participants