Skip to content

Commit 2e1833f

Browse files
[mlir][IR] Remove builder API + caching for low-precision FP types
1 parent d7e48fb commit 2e1833f

File tree

7 files changed

+60
-158
lines changed

7 files changed

+60
-158
lines changed

mlir/include/mlir/IR/Builders.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,6 @@ class Builder {
6161
Attribute metadata = Attribute());
6262

6363
// Types.
64-
FloatType getFloat4E2M1FNType();
65-
FloatType getFloat6E2M3FNType();
66-
FloatType getFloat6E3M2FNType();
67-
FloatType getFloat8E5M2Type();
68-
FloatType getFloat8E4M3Type();
69-
FloatType getFloat8E4M3FNType();
70-
FloatType getFloat8E5M2FNUZType();
71-
FloatType getFloat8E4M3FNUZType();
72-
FloatType getFloat8E4M3B11FNUZType();
73-
FloatType getFloat8E3M4Type();
74-
FloatType getFloat8E8M0FNUType();
7564
FloatType getBF16Type();
7665
FloatType getF16Type();
7766
FloatType getTF32Type();

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ class Builtin_FloatType<string name, string mnemonic,
8585
DeclareTypeInterfaceMethods<
8686
FloatTypeInterface,
8787
["getFloatSemantics"] # declaredInterfaceMethods>]> {
88+
}
89+
90+
// Float types that are cached in MLIRContext.
91+
class Builtin_CachedFloatType<string name, string mnemonic,
92+
list<string> declaredInterfaceMethods = []>
93+
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
8894
let extraClassDeclaration = [{
8995
static }] # name # [{Type get(MLIRContext *context);
9096
}];
@@ -326,52 +332,52 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
326332
//===----------------------------------------------------------------------===//
327333
// BFloat16Type
328334

329-
def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
335+
def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16",
330336
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
331337
let summary = "bfloat16 floating-point type";
332338
}
333339

334340
//===----------------------------------------------------------------------===//
335341
// Float16Type
336342

337-
def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
343+
def Builtin_Float16 : Builtin_CachedFloatType<"Float16", "f16",
338344
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
339345
let summary = "16-bit floating-point type";
340346
}
341347

342348
//===----------------------------------------------------------------------===//
343349
// FloatTF32Type
344350

345-
def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
351+
def Builtin_FloatTF32 : Builtin_CachedFloatType<"FloatTF32", "tf32"> {
346352
let summary = "TF32 floating-point type";
347353
}
348354

349355
//===----------------------------------------------------------------------===//
350356
// Float32Type
351357

352-
def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
358+
def Builtin_Float32 : Builtin_CachedFloatType<"Float32", "f32",
353359
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
354360
let summary = "32-bit floating-point type";
355361
}
356362

357363
//===----------------------------------------------------------------------===//
358364
// Float64Type
359365

360-
def Builtin_Float64 : Builtin_FloatType<"Float64", "f64"> {
366+
def Builtin_Float64 : Builtin_CachedFloatType<"Float64", "f64"> {
361367
let summary = "64-bit floating-point type";
362368
}
363369

364370
//===----------------------------------------------------------------------===//
365371
// Float80Type
366372

367-
def Builtin_Float80 : Builtin_FloatType<"Float80", "f80"> {
373+
def Builtin_Float80 : Builtin_CachedFloatType<"Float80", "f80"> {
368374
let summary = "80-bit floating-point type";
369375
}
370376

371377
//===----------------------------------------------------------------------===//
372378
// Float128Type
373379

374-
def Builtin_Float128 : Builtin_FloatType<"Float128", "f128"> {
380+
def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
375381
let summary = "128-bit floating-point type";
376382
}
377383

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -330,31 +330,31 @@ def F80 : F<80>;
330330
def F128 : F<128>;
331331

332332
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
333-
BuildableType<"$_builder.getBF16Type()">;
333+
BuildableType<"$_builder.getType<BFloat16Type>()">;
334334
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
335-
BuildableType<"$_builder.getTF32Type()">;
335+
BuildableType<"$_builder.getType<FloatTF32Type>()">;
336336
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
337-
BuildableType<"$_builder.getFloat8E4M3FNType()">;
337+
BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
338338
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
339-
BuildableType<"$_builder.getFloat8E5M2Type()">;
339+
BuildableType<"$_builder.getType<Float8E5M2Type>()">;
340340
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
341-
BuildableType<"$_builder.getFloat8E4M3Type()">;
341+
BuildableType<"$_builder.getType<Float8E4M3Type>()">;
342342
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
343-
BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
343+
BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
344344
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
345-
BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
345+
BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
346346
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
347-
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
347+
BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
348348
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
349-
BuildableType<"$_builder.getFloat8E3M4Type()">;
349+
BuildableType<"$_builder.getType<Float8E3M4Type>()">;
350350
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
351-
BuildableType<"$_builder.getFloat4E2M1FNType()">;
351+
BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
352352
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
353-
BuildableType<"$_builder.getFloat6E2M3FNType()">;
353+
BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
354354
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
355-
BuildableType<"$_builder.getFloat6E3M2FNType()">;
355+
BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
356356
def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
357-
BuildableType<"$_builder.getFloat8E8M0FNUType()">;
357+
BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
358358

359359
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
360360
"complex-type", "::mlir::ComplexType">;

mlir/lib/AsmParser/TypeParser.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -309,58 +309,58 @@ Type Parser::parseNonFunctionType() {
309309
// float-type
310310
case Token::kw_f4E2M1FN:
311311
consumeToken(Token::kw_f4E2M1FN);
312-
return builder.getFloat4E2M1FNType();
312+
return builder.getType<Float4E2M1FNType>();
313313
case Token::kw_f6E2M3FN:
314314
consumeToken(Token::kw_f6E2M3FN);
315-
return builder.getFloat6E2M3FNType();
315+
return builder.getType<Float6E2M3FNType>();
316316
case Token::kw_f6E3M2FN:
317317
consumeToken(Token::kw_f6E3M2FN);
318-
return builder.getFloat6E3M2FNType();
318+
return builder.getType<Float6E3M2FNType>();
319319
case Token::kw_f8E5M2:
320320
consumeToken(Token::kw_f8E5M2);
321-
return builder.getFloat8E5M2Type();
321+
return builder.getType<Float8E5M2Type>();
322322
case Token::kw_f8E4M3:
323323
consumeToken(Token::kw_f8E4M3);
324-
return builder.getFloat8E4M3Type();
324+
return builder.getType<Float8E4M3Type>();
325325
case Token::kw_f8E4M3FN:
326326
consumeToken(Token::kw_f8E4M3FN);
327-
return builder.getFloat8E4M3FNType();
327+
return builder.getType<Float8E4M3FNType>();
328328
case Token::kw_f8E5M2FNUZ:
329329
consumeToken(Token::kw_f8E5M2FNUZ);
330-
return builder.getFloat8E5M2FNUZType();
330+
return builder.getType<Float8E5M2FNUZType>();
331331
case Token::kw_f8E4M3FNUZ:
332332
consumeToken(Token::kw_f8E4M3FNUZ);
333-
return builder.getFloat8E4M3FNUZType();
333+
return builder.getType<Float8E4M3FNUZType>();
334334
case Token::kw_f8E4M3B11FNUZ:
335335
consumeToken(Token::kw_f8E4M3B11FNUZ);
336-
return builder.getFloat8E4M3B11FNUZType();
336+
return builder.getType<Float8E4M3B11FNUZType>();
337337
case Token::kw_f8E3M4:
338338
consumeToken(Token::kw_f8E3M4);
339-
return builder.getFloat8E3M4Type();
339+
return builder.getType<Float8E3M4Type>();
340340
case Token::kw_f8E8M0FNU:
341341
consumeToken(Token::kw_f8E8M0FNU);
342-
return builder.getFloat8E8M0FNUType();
342+
return builder.getType<Float8E8M0FNUType>();
343343
case Token::kw_bf16:
344344
consumeToken(Token::kw_bf16);
345-
return builder.getBF16Type();
345+
return builder.getType<BFloat16Type>();
346346
case Token::kw_f16:
347347
consumeToken(Token::kw_f16);
348-
return builder.getF16Type();
348+
return builder.getType<Float16Type>();
349349
case Token::kw_tf32:
350350
consumeToken(Token::kw_tf32);
351-
return builder.getTF32Type();
351+
return builder.getType<FloatTF32Type>();
352352
case Token::kw_f32:
353353
consumeToken(Token::kw_f32);
354-
return builder.getF32Type();
354+
return builder.getType<Float32Type>();
355355
case Token::kw_f64:
356356
consumeToken(Token::kw_f64);
357-
return builder.getF64Type();
357+
return builder.getType<Float64Type>();
358358
case Token::kw_f80:
359359
consumeToken(Token::kw_f80);
360-
return builder.getF80Type();
360+
return builder.getType<Float80Type>();
361361
case Token::kw_f128:
362362
consumeToken(Token::kw_f128);
363-
return builder.getF128Type();
363+
return builder.getType<Float128Type>();
364364

365365
// index-type
366366
case Token::kw_index:

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -361,22 +361,22 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
361361
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
362362
Builder b(ctx);
363363
return llvm::StringSwitch<std::optional<FloatType>>(name)
364-
.Case("f4E2M1FN", b.getFloat4E2M1FNType())
365-
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
366-
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
367-
.Case("f8E5M2", b.getFloat8E5M2Type())
368-
.Case("f8E4M3", b.getFloat8E4M3Type())
369-
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
370-
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
371-
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
372-
.Case("f8E3M4", b.getFloat8E3M4Type())
373-
.Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
374-
.Case("bf16", b.getBF16Type())
375-
.Case("f16", b.getF16Type())
376-
.Case("f32", b.getF32Type())
377-
.Case("f64", b.getF64Type())
378-
.Case("f80", b.getF80Type())
379-
.Case("f128", b.getF128Type())
364+
.Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
365+
.Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
366+
.Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
367+
.Case("f8E5M2", b.getType<Float8E5M2Type>())
368+
.Case("f8E4M3", b.getType<Float8E4M3Type>())
369+
.Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
370+
.Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
371+
.Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
372+
.Case("f8E3M4", b.getType<Float8E3M4Type>())
373+
.Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
374+
.Case("bf16", b.getType<BFloat16Type>())
375+
.Case("f16", b.getType<Float16Type>())
376+
.Case("f32", b.getType<Float32Type>())
377+
.Case("f64", b.getType<Float64Type>())
378+
.Case("f80", b.getType<Float80Type>())
379+
.Case("f128", b.getType<Float128Type>())
380380
.Default(std::nullopt);
381381
}
382382

mlir/lib/IR/Builders.cpp

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,44 +34,6 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
3434
// Types.
3535
//===----------------------------------------------------------------------===//
3636

37-
FloatType Builder::getFloat4E2M1FNType() {
38-
return Float4E2M1FNType::get(context);
39-
}
40-
41-
FloatType Builder::getFloat6E2M3FNType() {
42-
return Float6E2M3FNType::get(context);
43-
}
44-
45-
FloatType Builder::getFloat6E3M2FNType() {
46-
return Float6E3M2FNType::get(context);
47-
}
48-
49-
FloatType Builder::getFloat8E5M2Type() { return Float8E5M2Type::get(context); }
50-
51-
FloatType Builder::getFloat8E4M3Type() { return Float8E4M3Type::get(context); }
52-
53-
FloatType Builder::getFloat8E4M3FNType() {
54-
return Float8E4M3FNType::get(context);
55-
}
56-
57-
FloatType Builder::getFloat8E5M2FNUZType() {
58-
return Float8E5M2FNUZType::get(context);
59-
}
60-
61-
FloatType Builder::getFloat8E4M3FNUZType() {
62-
return Float8E4M3FNUZType::get(context);
63-
}
64-
65-
FloatType Builder::getFloat8E4M3B11FNUZType() {
66-
return Float8E4M3B11FNUZType::get(context);
67-
}
68-
69-
FloatType Builder::getFloat8E3M4Type() { return Float8E3M4Type::get(context); }
70-
71-
FloatType Builder::getFloat8E8M0FNUType() {
72-
return Float8E8M0FNUType::get(context);
73-
}
74-
7537
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
7638

7739
FloatType Builder::getF16Type() { return Float16Type::get(context); }

mlir/lib/IR/MLIRContext.cpp

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -221,17 +221,6 @@ class MLIRContextImpl {
221221
llvm::DenseMap<StringRef, AbstractType *> nameToType;
222222

223223
/// Cached Type Instances.
224-
Float4E2M1FNType f4E2M1FNTy;
225-
Float6E2M3FNType f6E2M3FNTy;
226-
Float6E3M2FNType f6E3M2FNTy;
227-
Float8E5M2Type f8E5M2Ty;
228-
Float8E4M3Type f8E4M3Ty;
229-
Float8E4M3FNType f8E4M3FNTy;
230-
Float8E5M2FNUZType f8E5M2FNUZTy;
231-
Float8E4M3FNUZType f8E4M3FNUZTy;
232-
Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
233-
Float8E3M4Type f8E3M4Ty;
234-
Float8E8M0FNUType f8E8M0FNUTy;
235224
BFloat16Type bf16Ty;
236225
Float16Type f16Ty;
237226
FloatTF32Type tf32Ty;
@@ -317,17 +306,6 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
317306

318307
//// Types.
319308
/// Floating-point Types.
320-
impl->f4E2M1FNTy = TypeUniquer::get<Float4E2M1FNType>(this);
321-
impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
322-
impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
323-
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
324-
impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
325-
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
326-
impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
327-
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
328-
impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
329-
impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
330-
impl->f8E8M0FNUTy = TypeUniquer::get<Float8E8M0FNUType>(this);
331309
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
332310
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
333311
impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1044,39 +1022,6 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
10441022
/// This should not be used directly.
10451023
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
10461024

1047-
Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
1048-
return context->getImpl().f4E2M1FNTy;
1049-
}
1050-
Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
1051-
return context->getImpl().f6E2M3FNTy;
1052-
}
1053-
Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
1054-
return context->getImpl().f6E3M2FNTy;
1055-
}
1056-
Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
1057-
return context->getImpl().f8E5M2Ty;
1058-
}
1059-
Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) {
1060-
return context->getImpl().f8E4M3Ty;
1061-
}
1062-
Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
1063-
return context->getImpl().f8E4M3FNTy;
1064-
}
1065-
Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
1066-
return context->getImpl().f8E5M2FNUZTy;
1067-
}
1068-
Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
1069-
return context->getImpl().f8E4M3FNUZTy;
1070-
}
1071-
Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
1072-
return context->getImpl().f8E4M3B11FNUZTy;
1073-
}
1074-
Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
1075-
return context->getImpl().f8E3M4Ty;
1076-
}
1077-
Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
1078-
return context->getImpl().f8E8M0FNUTy;
1079-
}
10801025
BFloat16Type BFloat16Type::get(MLIRContext *context) {
10811026
return context->getImpl().bf16Ty;
10821027
}

0 commit comments

Comments
 (0)