Skip to content

Commit 6685fd8

Browse files
jfurtekjoker-eph
authored andcommitted
[mlir] Add support for TF32 as a Builtin FloatType
This diff adds support for TF32 as a Builtin floating point type. This supplements the recent addition of the TF32 semantic to the LLVM APFloat class by extending usage to MLIR. https://reviews.llvm.org/D151923 More information on the TF32 type can be found here: https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/ Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D153705
1 parent 8f7e41d commit 6685fd8

File tree

20 files changed

+93
-3
lines changed

20 files changed

+93
-3
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
163163
/// context.
164164
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx);
165165

166+
/// Returns the typeID of a TF32 type.
167+
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void);
168+
169+
/// Checks whether the given type is an TF32 type.
170+
MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type);
171+
172+
/// Creates a TF32 type in the given context. The type is owned by the
173+
/// context.
174+
MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx);
175+
166176
//===----------------------------------------------------------------------===//
167177
// None type.
168178
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/Builders.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class Builder {
6767
FloatType getFloat8E4M3B11FNUZType();
6868
FloatType getBF16Type();
6969
FloatType getF16Type();
70+
FloatType getTF32Type();
7071
FloatType getF32Type();
7172
FloatType getF64Type();
7273
FloatType getF80Type();

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class FloatType : public Type {
4444
static FloatType getBF16(MLIRContext *ctx);
4545
static FloatType getF16(MLIRContext *ctx);
4646
static FloatType getF32(MLIRContext *ctx);
47+
static FloatType getTF32(MLIRContext *ctx);
4748
static FloatType getF64(MLIRContext *ctx);
4849
static FloatType getF80(MLIRContext *ctx);
4950
static FloatType getF128(MLIRContext *ctx);
@@ -417,8 +418,8 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
417418
inline bool FloatType::classof(Type type) {
418419
return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
419420
Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type,
420-
Float16Type, Float32Type, Float64Type, Float80Type,
421-
Float128Type>(type);
421+
Float16Type, FloatTF32Type, Float32Type, Float64Type,
422+
Float80Type, Float128Type>(type);
422423
}
423424

424425
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -449,6 +450,10 @@ inline FloatType FloatType::getF16(MLIRContext *ctx) {
449450
return Float16Type::get(ctx);
450451
}
451452

453+
inline FloatType FloatType::getTF32(MLIRContext *ctx) {
454+
return FloatTF32Type::get(ctx);
455+
}
456+
452457
inline FloatType FloatType::getF32(MLIRContext *ctx) {
453458
return Float32Type::get(ctx);
454459
}

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,13 @@ def Builtin_Float16 : Builtin_FloatType<"Float16"> {
198198
let summary = "16-bit floating-point type";
199199
}
200200

201+
//===----------------------------------------------------------------------===//
202+
// FloatTF32Type
203+
204+
def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32"> {
205+
let summary = "TF32 floating-point type";
206+
}
207+
201208
//===----------------------------------------------------------------------===//
202209
// Float32Type
203210

mlir/include/mlir/IR/OpBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ def F128 : F<128>;
570570

571571
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
572572
BuildableType<"$_builder.getBF16Type()">;
573+
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
574+
BuildableType<"$_builder.getTF32Type()">;
573575
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
574576
BuildableType<"$_builder.getFloat8E4M3FNType()">;
575577
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,

mlir/include/mlir/IR/Types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class Type {
127127
bool isFloat8E4M3B11FNUZ() const;
128128
bool isBF16() const;
129129
bool isF16() const;
130+
bool isTF32() const;
130131
bool isF32() const;
131132
bool isF64() const;
132133
bool isF80() const;

mlir/lib/AsmParser/TokenKinds.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ TOK_KEYWORD(step)
117117
TOK_KEYWORD(strided)
118118
TOK_KEYWORD(symbol)
119119
TOK_KEYWORD(tensor)
120+
TOK_KEYWORD(tf32)
120121
TOK_KEYWORD(to)
121122
TOK_KEYWORD(true)
122123
TOK_KEYWORD(tuple)

mlir/lib/AsmParser/TypeParser.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
3838
case Token::kw_f8E4M3B11FNUZ:
3939
case Token::kw_bf16:
4040
case Token::kw_f16:
41+
case Token::kw_tf32:
4142
case Token::kw_f32:
4243
case Token::kw_f64:
4344
case Token::kw_f80:
@@ -313,6 +314,9 @@ Type Parser::parseNonFunctionType() {
313314
case Token::kw_f16:
314315
consumeToken(Token::kw_f16);
315316
return builder.getF16Type();
317+
case Token::kw_tf32:
318+
consumeToken(Token::kw_tf32);
319+
return builder.getTF32Type();
316320
case Token::kw_f32:
317321
consumeToken(Token::kw_f32);
318322
return builder.getF32Type();

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,26 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
247247
}
248248
};
249249

250+
/// Floating Point Type subclass - TF32Type.
251+
class PyTF32Type : public PyConcreteType<PyTF32Type> {
252+
public:
253+
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
254+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
255+
mlirFloatTF32TypeGetTypeID;
256+
static constexpr const char *pyClassName = "FloatTF32Type";
257+
using PyConcreteType::PyConcreteType;
258+
259+
static void bindDerived(ClassTy &c) {
260+
c.def_static(
261+
"get",
262+
[](DefaultingPyMlirContext context) {
263+
MlirType t = mlirTF32TypeGet(context->get());
264+
return PyTF32Type(context->getRef(), t);
265+
},
266+
py::arg("context") = py::none(), "Create a tf32 type.");
267+
}
268+
};
269+
250270
/// Floating Point Type subclass - F32Type.
251271
class PyF32Type : public PyConcreteType<PyF32Type> {
252272
public:
@@ -754,6 +774,7 @@ void mlir::python::populateIRTypes(py::module &m) {
754774
PyFloat8E5M2FNUZType::bind(m);
755775
PyBF16Type::bind(m);
756776
PyF16Type::bind(m);
777+
PyTF32Type::bind(m);
757778
PyF32Type::bind(m);
758779
PyF64Type::bind(m);
759780
PyNoneType::bind(m);

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,16 @@ MlirType mlirF16TypeGet(MlirContext ctx) {
152152
return wrap(FloatType::getF16(unwrap(ctx)));
153153
}
154154

155+
MlirTypeID mlirFloatTF32TypeGetTypeID() {
156+
return wrap(FloatTF32Type::getTypeID());
157+
}
158+
159+
bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
160+
161+
MlirType mlirTF32TypeGet(MlirContext ctx) {
162+
return wrap(FloatType::getTF32(unwrap(ctx)));
163+
}
164+
155165
MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
156166

157167
bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }

0 commit comments

Comments
 (0)