Skip to content

Commit c24ce32

Browse files
[mlir][IR] Turn FloatType into a type interface (#118891)
This makes it possible to add new MLIR floating point types in downstream projects. (Adding new APFloat semantics in downstream projects is not possible yet, so parsing/printing/converting float literals of newly added types is not supported.) Also removes two functions where we had to hard-code all existing floating point types (`FloatType::classof`). See discussion here: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361 No measurable compilation time changes for these lit tests: ``` Benchmark 1: mlir-opt ./mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir -split-input-file -convert-vector-to-llvm -o /dev/null BEFORE Time (mean ± σ): 248.4 ms ± 3.2 ms [User: 237.0 ms, System: 20.1 ms] Range (min … max): 243.3 ms … 255.9 ms 30 runs AFTER Time (mean ± σ): 246.8 ms ± 3.2 ms [User: 233.2 ms, System: 21.8 ms] Range (min … max): 240.2 ms … 252.1 ms 30 runs Benchmark 2: mlir-opt- ./mlir/test/Dialect/Arith/canonicalize.mlir -split-input-file -canonicalize -o /dev/null BEFORE Time (mean ± σ): 37.3 ms ± 1.8 ms [User: 31.6 ms, System: 30.4 ms] Range (min … max): 34.6 ms … 42.0 ms 200 runs AFTER Time (mean ± σ): 37.5 ms ± 2.0 ms [User: 31.5 ms, System: 29.2 ms] Range (min … max): 34.5 ms … 43.0 ms 200 runs Benchmark 3: mlir-opt ./mlir/test/Dialect/Tensor/canonicalize.mlir -split-input-file -canonicalize -allow-unregistered-dialect -o /dev/null BEFORE Time (mean ± σ): 152.2 ms ± 2.5 ms [User: 140.1 ms, System: 12.2 ms] Range (min … max): 147.6 ms … 161.8 ms 200 runs AFTER Time (mean ± σ): 151.9 ms ± 2.7 ms [User: 140.5 ms, System: 11.5 ms] Range (min … max): 147.2 ms … 159.1 ms 200 runs ``` A micro benchmark that parses + prints 32768 floats with random floating-point type shows a slowdown from 55.1 ms -> 48.3 ms.
1 parent 4c2e4ea commit c24ce32

File tree

7 files changed

+138
-124
lines changed

7 files changed

+138
-124
lines changed

mlir/include/mlir/IR/BuiltinTypeInterfaces.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111

1212
#include "mlir/IR/Types.h"
1313

14+
namespace llvm {
15+
struct fltSemantics;
16+
} // namespace llvm
17+
18+
namespace mlir {
19+
class FloatType;
20+
class MLIRContext;
21+
} // namespace mlir
22+
1423
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
1524

1625
#endif // MLIR_IR_BUILTINTYPEINTERFACES_H

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,65 @@
1616

1717
include "mlir/IR/OpBase.td"
1818

19+
def FloatTypeInterface : TypeInterface<"FloatType"> {
20+
let cppNamespace = "::mlir";
21+
let description = [{
22+
This type interface should be implemented by all floating-point types. It
23+
defines the LLVM APFloat semantics and provides a few helper functions.
24+
}];
25+
26+
let methods = [
27+
InterfaceMethod<
28+
/*desc=*/[{
29+
Returns the APFloat semantics for this floating-point type.
30+
}],
31+
/*retTy=*/"const ::llvm::fltSemantics &",
32+
/*methodName=*/"getFloatSemantics",
33+
/*args=*/(ins)
34+
>,
35+
InterfaceMethod<
36+
/*desc=*/[{
37+
Returns a float type with bitwidth scaled by `scale`. Returns a "null"
38+
float type if the scaled element type cannot be represented.
39+
}],
40+
/*retTy=*/"::mlir::FloatType",
41+
/*methodName=*/"scaleElementBitwidth",
42+
/*args=*/(ins "unsigned":$scale),
43+
/*methodBody=*/"",
44+
/*defaultImplementation=*/"return ::mlir::FloatType();"
45+
>
46+
];
47+
48+
let extraClassDeclaration = [{
49+
// Convenience factories.
50+
static FloatType getBF16(MLIRContext *ctx);
51+
static FloatType getF16(MLIRContext *ctx);
52+
static FloatType getF32(MLIRContext *ctx);
53+
static FloatType getTF32(MLIRContext *ctx);
54+
static FloatType getF64(MLIRContext *ctx);
55+
static FloatType getF80(MLIRContext *ctx);
56+
static FloatType getF128(MLIRContext *ctx);
57+
static FloatType getFloat8E5M2(MLIRContext *ctx);
58+
static FloatType getFloat8E4M3(MLIRContext *ctx);
59+
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
60+
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
61+
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
62+
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
63+
static FloatType getFloat8E3M4(MLIRContext *ctx);
64+
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
65+
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
66+
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
67+
static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
68+
69+
/// Return the bitwidth of this float type.
70+
unsigned getWidth();
71+
72+
/// Return the width of the mantissa of this type.
73+
/// The width includes the integer bit.
74+
unsigned getFPMantissaWidth();
75+
}];
76+
}
77+
1978
//===----------------------------------------------------------------------===//
2079
// MemRefElementTypeInterface
2180
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ struct fltSemantics;
2525
namespace mlir {
2626
class AffineExpr;
2727
class AffineMap;
28-
class FloatType;
2928
class IndexType;
3029
class IntegerType;
3130
class MemRefType;
@@ -44,52 +43,6 @@ template <typename ConcreteType>
4443
class ValueSemantics
4544
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
4645

47-
//===----------------------------------------------------------------------===//
48-
// FloatType
49-
//===----------------------------------------------------------------------===//
50-
51-
class FloatType : public Type {
52-
public:
53-
using Type::Type;
54-
55-
// Convenience factories.
56-
static FloatType getBF16(MLIRContext *ctx);
57-
static FloatType getF16(MLIRContext *ctx);
58-
static FloatType getF32(MLIRContext *ctx);
59-
static FloatType getTF32(MLIRContext *ctx);
60-
static FloatType getF64(MLIRContext *ctx);
61-
static FloatType getF80(MLIRContext *ctx);
62-
static FloatType getF128(MLIRContext *ctx);
63-
static FloatType getFloat8E5M2(MLIRContext *ctx);
64-
static FloatType getFloat8E4M3(MLIRContext *ctx);
65-
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
66-
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
67-
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
68-
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
69-
static FloatType getFloat8E3M4(MLIRContext *ctx);
70-
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
71-
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
72-
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
73-
static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
74-
75-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
76-
static bool classof(Type type);
77-
78-
/// Return the bitwidth of this float type.
79-
unsigned getWidth();
80-
81-
/// Return the width of the mantissa of this type.
82-
/// The width includes the integer bit.
83-
unsigned getFPMantissaWidth();
84-
85-
/// Get or create a new FloatType with bitwidth scaled by `scale`.
86-
/// Return null if the scaled element type cannot be represented.
87-
FloatType scaleElementBitwidth(unsigned scale);
88-
89-
/// Return the floating semantics of this float type.
90-
const llvm::fltSemantics &getFloatSemantics();
91-
};
92-
9346
//===----------------------------------------------------------------------===//
9447
// TensorType
9548
//===----------------------------------------------------------------------===//
@@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
448401
llvm::isa<MemRefElementTypeInterface>(type);
449402
}
450403

451-
inline bool FloatType::classof(Type type) {
452-
return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
453-
Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
454-
Float8E5M2FNUZType, Float8E4M3FNUZType,
455-
Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
456-
BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
457-
Float64Type, Float80Type, Float128Type>(type);
458-
}
459-
460404
inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
461405
return Float4E2M1FNType::get(ctx);
462406
}

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,12 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
7979
//===----------------------------------------------------------------------===//
8080

8181
// Base class for Builtin dialect float types.
82-
class Builtin_FloatType<string name, string mnemonic>
83-
: Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
82+
class Builtin_FloatType<string name, string mnemonic,
83+
list<string> declaredInterfaceMethods = []>
84+
: Builtin_Type<name, mnemonic, /*traits=*/[
85+
DeclareTypeInterfaceMethods<
86+
FloatTypeInterface,
87+
["getFloatSemantics"] # declaredInterfaceMethods>]> {
8488
let extraClassDeclaration = [{
8589
static }] # name # [{Type get(MLIRContext *context);
8690
}];
@@ -322,14 +326,16 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
322326
//===----------------------------------------------------------------------===//
323327
// BFloat16Type
324328

325-
def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16"> {
329+
def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
330+
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
326331
let summary = "bfloat16 floating-point type";
327332
}
328333

329334
//===----------------------------------------------------------------------===//
330335
// Float16Type
331336

332-
def Builtin_Float16 : Builtin_FloatType<"Float16", "f16"> {
337+
def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
338+
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
333339
let summary = "16-bit floating-point type";
334340
}
335341

@@ -343,7 +349,8 @@ def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
343349
//===----------------------------------------------------------------------===//
344350
// Float32Type
345351

346-
def Builtin_Float32 : Builtin_FloatType<"Float32", "f32"> {
352+
def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
353+
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
347354
let summary = "32-bit floating-point type";
348355
}
349356

mlir/lib/IR/BuiltinTypeInterfaces.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/IR/BuiltinTypes.h"
1010
#include "mlir/IR/Diagnostics.h"
11+
#include "llvm/ADT/APFloat.h"
1112
#include "llvm/ADT/Sequence.h"
1213

1314
using namespace mlir;
@@ -19,6 +20,18 @@ using namespace mlir::detail;
1920

2021
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
2122

23+
//===----------------------------------------------------------------------===//
24+
// FloatType
25+
//===----------------------------------------------------------------------===//
26+
27+
unsigned FloatType::getWidth() {
28+
return APFloat::semanticsSizeInBits(getFloatSemantics());
29+
}
30+
31+
unsigned FloatType::getFPMantissaWidth() {
32+
return APFloat::semanticsPrecision(getFloatSemantics());
33+
}
34+
2235
//===----------------------------------------------------------------------===//
2336
// ShapedType
2437
//===----------------------------------------------------------------------===//

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -87,72 +87,54 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
8787
}
8888

8989
//===----------------------------------------------------------------------===//
90-
// Float Type
91-
//===----------------------------------------------------------------------===//
92-
93-
unsigned FloatType::getWidth() {
94-
return APFloat::semanticsSizeInBits(getFloatSemantics());
95-
}
96-
97-
/// Returns the floating semantics for the given type.
98-
const llvm::fltSemantics &FloatType::getFloatSemantics() {
99-
if (llvm::isa<Float4E2M1FNType>(*this))
100-
return APFloat::Float4E2M1FN();
101-
if (llvm::isa<Float6E2M3FNType>(*this))
102-
return APFloat::Float6E2M3FN();
103-
if (llvm::isa<Float6E3M2FNType>(*this))
104-
return APFloat::Float6E3M2FN();
105-
if (llvm::isa<Float8E5M2Type>(*this))
106-
return APFloat::Float8E5M2();
107-
if (llvm::isa<Float8E4M3Type>(*this))
108-
return APFloat::Float8E4M3();
109-
if (llvm::isa<Float8E4M3FNType>(*this))
110-
return APFloat::Float8E4M3FN();
111-
if (llvm::isa<Float8E5M2FNUZType>(*this))
112-
return APFloat::Float8E5M2FNUZ();
113-
if (llvm::isa<Float8E4M3FNUZType>(*this))
114-
return APFloat::Float8E4M3FNUZ();
115-
if (llvm::isa<Float8E4M3B11FNUZType>(*this))
116-
return APFloat::Float8E4M3B11FNUZ();
117-
if (llvm::isa<Float8E3M4Type>(*this))
118-
return APFloat::Float8E3M4();
119-
if (llvm::isa<Float8E8M0FNUType>(*this))
120-
return APFloat::Float8E8M0FNU();
121-
if (llvm::isa<BFloat16Type>(*this))
122-
return APFloat::BFloat();
123-
if (llvm::isa<Float16Type>(*this))
124-
return APFloat::IEEEhalf();
125-
if (llvm::isa<FloatTF32Type>(*this))
126-
return APFloat::FloatTF32();
127-
if (llvm::isa<Float32Type>(*this))
128-
return APFloat::IEEEsingle();
129-
if (llvm::isa<Float64Type>(*this))
130-
return APFloat::IEEEdouble();
131-
if (llvm::isa<Float80Type>(*this))
132-
return APFloat::x87DoubleExtended();
133-
if (llvm::isa<Float128Type>(*this))
134-
return APFloat::IEEEquad();
135-
llvm_unreachable("non-floating point type used");
136-
}
137-
138-
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
139-
if (!scale)
140-
return FloatType();
141-
MLIRContext *ctx = getContext();
142-
if (isF16() || isBF16()) {
143-
if (scale == 2)
144-
return FloatType::getF32(ctx);
145-
if (scale == 4)
146-
return FloatType::getF64(ctx);
90+
// Float Types
91+
//===----------------------------------------------------------------------===//
92+
93+
// Mapping from MLIR FloatType to APFloat semantics.
94+
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
95+
const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
96+
return APFloat::SEM(); \
14797
}
148-
if (isF32())
149-
if (scale == 2)
150-
return FloatType::getF64(ctx);
98+
FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
99+
FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
100+
FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
101+
FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
102+
FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
103+
FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
104+
FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
105+
FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
106+
FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
107+
FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
108+
FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
109+
FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
110+
FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
111+
FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
112+
FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
113+
FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
114+
FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
115+
FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
116+
#undef FLOAT_TYPE_SEMANTICS
117+
118+
FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
119+
if (scale == 2)
120+
return FloatType::getF32(getContext());
121+
if (scale == 4)
122+
return FloatType::getF64(getContext());
151123
return FloatType();
152124
}
153125

154-
unsigned FloatType::getFPMantissaWidth() {
155-
return APFloat::semanticsPrecision(getFloatSemantics());
126+
FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
127+
if (scale == 2)
128+
return FloatType::getF32(getContext());
129+
if (scale == 4)
130+
return FloatType::getF64(getContext());
131+
return FloatType();
132+
}
133+
134+
FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
135+
if (scale == 2)
136+
return FloatType::getF64(getContext());
137+
return FloatType();
156138
}
157139

158140
//===----------------------------------------------------------------------===//

mlir/unittests/IR/InterfaceAttachmentTest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct Model
4343
/// overrides default methods.
4444
struct OverridingModel
4545
: public TestExternalTypeInterface::ExternalModel<OverridingModel,
46-
FloatType> {
46+
Float32Type> {
4747
unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
4848
return type.getIntOrFloatBitWidth() + arg;
4949
}

0 commit comments

Comments
 (0)