Skip to content

Commit b56e65d

Browse files
jfurtekvzakhari
authored andcommitted
[mlir][arith] Initial support for fastmath flag attributes in the Arithmetic dialect (v2)
This diff adds initial (partial) support for "fastmath" attributes for floating point operations in the arithmetic dialect. The "fastmath" attributes are implemented using a default-valued bit enum. The defined flags currently mirror the fastmath flags in the LLVM dialect (and in LLVM itself). Extending the set of flags (if necessary) is left as a future task. In this diff: - Definition of FastMathAttr as a custom attribute in the Arithmetic dialect that inherits from the EnumAttr class. - Definition of ArithFastMathInterface, which is an interface that is implemented by operations that have an arith::fastmath attribute. - Declaration of a default-valued fastmath attribute for unary and (some) binary floating point operations in the Arithmetic dialect. - Conversion code to lower arithmetic fastmath flags to LLVM fastmath flags NOT in this diff (but planned or currently in progress): - Documentation of flag meanings - Addition of FastMathAttr attributes to other dialects that might lower to the Arithmetic dialect (e.g. Math and Complex) - Folding/rewrite implementations that are enabled by fastmath flags - Specification of fastmath values from Python bindings (pending other in- progress diffs) Reviewed By: mehdi_amini, vzakhari Differential Revision: https://reviews.llvm.org/D126305
1 parent f6eb089 commit b56e65d

File tree

24 files changed

+372
-44
lines changed

24 files changed

+372
-44
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ namespace detail {
2222
/// and given operands.
2323
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
2424
ValueRange operands,
25+
ArrayRef<NamedAttribute> targetAttrs,
2526
LLVMTypeConverter &typeConverter,
2627
ConversionPatternRewriter &rewriter);
28+
2729
} // namespace detail
2830
} // namespace LLVM
2931

@@ -197,7 +199,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
197199
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
198200
ConversionPatternRewriter &rewriter) const override {
199201
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
200-
adaptor.getOperands(),
202+
adaptor.getOperands(), op->getAttrs(),
201203
*this->getTypeConverter(), rewriter);
202204
}
203205
};

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,34 @@ LogicalResult handleMultidimensionalVectors(
5656

5757
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
5858
ValueRange operands,
59+
ArrayRef<NamedAttribute> targetAttrs,
5960
LLVMTypeConverter &typeConverter,
6061
ConversionPatternRewriter &rewriter);
6162
} // namespace detail
6263
} // namespace LLVM
6364

65+
// Default attribute conversion class, which passes all source attributes
66+
// through to the target op, unmodified.
67+
template <typename SourceOp, typename TargetOp>
68+
class AttrConvertPassThrough {
69+
public:
70+
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
71+
72+
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
73+
74+
private:
75+
ArrayRef<NamedAttribute> srcAttrs;
76+
};
77+
6478
/// Basic lowering implementation to rewrite Ops with just one result to the
6579
/// LLVM Dialect. This supports higher-dimensional vector types.
66-
template <typename SourceOp, typename TargetOp>
80+
/// The AttrConvert template template parameter should be a template class
81+
/// with SourceOp and TargetOp type parameters, a constructor that takes
82+
/// a SourceOp instance, and a getAttrs() method that returns
83+
/// ArrayRef<NamedAttribute>.
84+
template <typename SourceOp, typename TargetOp,
85+
template <typename, typename> typename AttrConvert =
86+
AttrConvertPassThrough>
6787
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
6888
public:
6989
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -75,9 +95,12 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
7595
static_assert(
7696
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
7797
"expected single result op");
98+
// Determine attributes for the target op
99+
AttrConvert<SourceOp, TargetOp> attrConvert(op);
100+
78101
return LLVM::detail::vectorOneToOneRewrite(
79102
op, TargetOp::getOperationName(), adaptor.getOperands(),
80-
*this->getTypeConverter(), rewriter);
103+
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
81104
}
82105
};
83106
} // namespace mlir

mlir/include/mlir/Dialect/Arith/IR/Arith.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Interfaces/InferTypeOpInterface.h"
1818
#include "mlir/Interfaces/SideEffectInterfaces.h"
1919
#include "mlir/Interfaces/VectorInterfaces.h"
20+
#include "llvm/ADT/StringExtras.h"
2021

2122
//===----------------------------------------------------------------------===//
2223
// ArithDialect
@@ -29,6 +30,13 @@
2930
//===----------------------------------------------------------------------===//
3031

3132
#include "mlir/Dialect/Arith/IR/ArithOpsEnums.h.inc"
33+
#define GET_ATTRDEF_CLASSES
34+
#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.h.inc"
35+
36+
//===----------------------------------------------------------------------===//
37+
// Arith Interfaces
38+
//===----------------------------------------------------------------------===//
39+
#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.h.inc"
3240

3341
//===----------------------------------------------------------------------===//
3442
// Arith Dialect Operations

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def Arith_Dialect : Dialect {
2323
}];
2424

2525
let hasConstantMaterializer = 1;
26+
let useDefaultAttributePrinterParser = 1;
2627
}
2728

2829
// The predicate indicates the type of the comparison to perform:
@@ -92,4 +93,32 @@ def AtomicRMWKindAttr : I64EnumAttr<
9293
let cppNamespace = "::mlir::arith";
9394
}
9495

96+
def FASTMATH_NONE : I32BitEnumAttrCaseNone<"none" >;
97+
def FASTMATH_REASSOC : I32BitEnumAttrCaseBit<"reassoc", 0>;
98+
def FASTMATH_NO_NANS : I32BitEnumAttrCaseBit<"nnan", 1>;
99+
def FASTMATH_NO_INFS : I32BitEnumAttrCaseBit<"ninf", 2>;
100+
def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz", 3>;
101+
def FASTMATH_ALLOW_RECIP : I32BitEnumAttrCaseBit<"arcp", 4>;
102+
def FASTMATH_ALLOW_CONTRACT : I32BitEnumAttrCaseBit<"contract", 5>;
103+
def FASTMATH_APPROX_FUNC : I32BitEnumAttrCaseBit<"afn", 6>;
104+
def FASTMATH_FAST : I32BitEnumAttrCaseGroup<
105+
"fast",
106+
[
107+
FASTMATH_REASSOC, FASTMATH_NO_NANS, FASTMATH_NO_INFS,
108+
FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT,
109+
FASTMATH_APPROX_FUNC]>;
110+
111+
def FastMathFlags : I32BitEnumAttr<
112+
"FastMathFlags",
113+
"Floating point fast math flags",
114+
[
115+
FASTMATH_NONE, FASTMATH_REASSOC, FASTMATH_NO_NANS,
116+
FASTMATH_NO_INFS, FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP,
117+
FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FAST]> {
118+
let separator = ",";
119+
let cppNamespace = "::mlir::arith";
120+
let genSpecializedAttr = 0;
121+
let printBitEnumPrimaryGroups = 1;
122+
}
123+
95124
#endif // ARITH_BASE

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,20 @@
1010
#define ARITH_OPS
1111

1212
include "mlir/Dialect/Arith/IR/ArithBase.td"
13+
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
1314
include "mlir/Interfaces/CastInterfaces.td"
1415
include "mlir/Interfaces/InferIntRangeInterface.td"
1516
include "mlir/Interfaces/InferTypeOpInterface.td"
1617
include "mlir/Interfaces/SideEffectInterfaces.td"
1718
include "mlir/Interfaces/VectorInterfaces.td"
1819
include "mlir/IR/BuiltinAttributeInterfaces.td"
1920
include "mlir/IR/OpAsmInterface.td"
21+
include "mlir/IR/EnumAttr.td"
22+
23+
def Arith_FastMathAttr :
24+
EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
25+
let assemblyFormat = "`<` $value `>`";
26+
}
2027

2128
// Base class for Arith dialect ops. Ops in this dialect have no side
2229
// effects and can be applied element-wise to vectors and tensors.
@@ -58,15 +65,27 @@ class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
5865

5966
// Base class for floating point unary operations.
6067
class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
61-
Arith_UnaryOp<mnemonic, traits>,
62-
Arguments<(ins FloatLike:$operand)>,
63-
Results<(outs FloatLike:$result)>;
68+
Arith_UnaryOp<mnemonic,
69+
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
70+
traits)>,
71+
Arguments<(ins FloatLike:$operand,
72+
DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
73+
Results<(outs FloatLike:$result)> {
74+
let assemblyFormat = [{ $operand custom<ArithFastMathAttr>($fastmath)
75+
attr-dict `:` type($result) }];
76+
}
6477

6578
// Base class for floating point binary operations.
6679
class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
67-
Arith_BinaryOp<mnemonic, traits>,
68-
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>,
69-
Results<(outs FloatLike:$result)>;
80+
Arith_BinaryOp<mnemonic,
81+
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
82+
traits)>,
83+
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
84+
DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
85+
Results<(outs FloatLike:$result)> {
86+
let assemblyFormat = [{ $lhs `,` $rhs `` custom<ArithFastMathAttr>($fastmath)
87+
attr-dict `:` type($result) }];
88+
}
7089

7190
// Base class for arithmetic cast operations. Requires a single operand and
7291
// result. If either is a shaped type, then the other must be of the same shape.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//===-- ArithOpsInterfaces.td - arith op interfaces ---*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This is the Arith interfaces definition file.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef ARITH_OPS_INTERFACES
14+
#define ARITH_OPS_INTERFACES
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
19+
let description = [{
20+
Access to operation fastmath flags.
21+
}];
22+
23+
let cppNamespace = "::mlir::arith";
24+
25+
let methods = [
26+
InterfaceMethod<
27+
/*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation",
28+
/*returnType=*/ "FastMathFlagsAttr",
29+
/*methodName=*/ "getFastMathFlagsAttr",
30+
/*args=*/ (ins),
31+
/*methodBody=*/ [{}],
32+
/*defaultImpl=*/ [{
33+
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
34+
return op.getFastmathAttr();
35+
}]
36+
>,
37+
StaticInterfaceMethod<
38+
/*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute
39+
for the operation}],
40+
/*returnType=*/ "StringRef",
41+
/*methodName=*/ "getFastMathAttrName",
42+
/*args=*/ (ins),
43+
/*methodBody=*/ [{}],
44+
/*defaultImpl=*/ [{
45+
return "fastmath";
46+
}]
47+
>
48+
49+
];
50+
}
51+
52+
#endif // ARITH_OPS_INTERFACES
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
set(LLVM_TARGET_DEFINITIONS ArithOps.td)
22
mlir_tablegen(ArithOpsEnums.h.inc -gen-enum-decls)
33
mlir_tablegen(ArithOpsEnums.cpp.inc -gen-enum-defs)
4+
mlir_tablegen(ArithOpsAttributes.h.inc -gen-attrdef-decls
5+
-attrdefs-dialect=arith)
6+
mlir_tablegen(ArithOpsAttributes.cpp.inc -gen-attrdef-defs
7+
-attrdefs-dialect=arith)
48
add_mlir_dialect(ArithOps arith)
59
add_mlir_doc(ArithOps ArithOps Dialects/ -gen-dialect-doc)
10+
11+
set(LLVM_TARGET_DEFINITIONS ArithOpsInterfaces.td)
12+
mlir_tablegen(ArithOpsInterfaces.h.inc -gen-op-interface-decls)
13+
mlir_tablegen(ArithOpsInterfaces.cpp.inc -gen-op-interface-defs)
14+
add_public_tablegen_target(MLIRArithOpsInterfacesIncGen)

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ class SmartMutex;
4848
namespace mlir {
4949
namespace LLVM {
5050
class LLVMDialect;
51-
class LoopOptionsAttrBuilder;
5251

5352
namespace detail {
5453
struct LLVMTypeStorage;

mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,28 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
2323
let cppNamespace = "::mlir::LLVM";
2424

2525
let methods = [
26-
InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags",
27-
"getFastmathFlags">,
26+
InterfaceMethod<
27+
/*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation",
28+
/*returnType=*/ "FastmathFlagsAttr",
29+
/*methodName=*/ "getFastmathAttr",
30+
/*args=*/ (ins),
31+
/*methodBody=*/ [{}],
32+
/*defaultImpl=*/ [{
33+
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
34+
return op.getFastmathFlagsAttr();
35+
}]
36+
>,
37+
StaticInterfaceMethod<
38+
/*desc=*/ [{Returns the name of the FastmathFlagsAttr attribute
39+
for the operation}],
40+
/*returnType=*/ "StringRef",
41+
/*methodName=*/ "getFastmathAttrName",
42+
/*args=*/ (ins),
43+
/*methodBody=*/ [{}],
44+
/*defaultImpl=*/ [{
45+
return "fastmathFlags";
46+
}]
47+
>
2848
];
2949
}
3050

0 commit comments

Comments
 (0)