-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][IR] Remove builder API + caching for low-precision FP types #123321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Remove builder API + caching for low-precision FP types #123321
Conversation
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir-arith Author: Matthias Springer (matthias-springer) ChangesRemove builder API (e.g., For details, see: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28 Note for LLVM integration: Use Full diff: https://github.com/llvm/llvm-project/pull/123321.diff 7 Files Affected:
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index daea2a23d6fbed..cd8d3ee0af72b0 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -61,17 +61,6 @@ class Builder {
Attribute metadata = Attribute());
// Types.
- FloatType getFloat4E2M1FNType();
- FloatType getFloat6E2M3FNType();
- FloatType getFloat6E3M2FNType();
- FloatType getFloat8E5M2Type();
- FloatType getFloat8E4M3Type();
- FloatType getFloat8E4M3FNType();
- FloatType getFloat8E5M2FNUZType();
- FloatType getFloat8E4M3FNUZType();
- FloatType getFloat8E4M3B11FNUZType();
- FloatType getFloat8E3M4Type();
- FloatType getFloat8E8M0FNUType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index fc50b28c09e41c..4f09d2e41e7ceb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,12 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+}
+
+// Float types that are cached in MLIRContext.
+class Builtin_CachedFloatType<string name, string mnemonic,
+ list<string> declaredInterfaceMethods = []>
+ : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
@@ -326,7 +332,7 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
//===----------------------------------------------------------------------===//
// BFloat16Type
-def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
+def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "bfloat16 floating-point type";
}
@@ -334,7 +340,7 @@ def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
//===----------------------------------------------------------------------===//
// Float16Type
-def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
+def Builtin_Float16 : Builtin_CachedFloatType<"Float16", "f16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "16-bit floating-point type";
}
@@ -342,14 +348,14 @@ def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
//===----------------------------------------------------------------------===//
// FloatTF32Type
-def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
+def Builtin_FloatTF32 : Builtin_CachedFloatType<"FloatTF32", "tf32"> {
let summary = "TF32 floating-point type";
}
//===----------------------------------------------------------------------===//
// Float32Type
-def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
+def Builtin_Float32 : Builtin_CachedFloatType<"Float32", "f32",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "32-bit floating-point type";
}
@@ -357,21 +363,21 @@ def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
//===----------------------------------------------------------------------===//
// Float64Type
-def Builtin_Float64 : Builtin_FloatType<"Float64", "f64"> {
+def Builtin_Float64 : Builtin_CachedFloatType<"Float64", "f64"> {
let summary = "64-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float80Type
-def Builtin_Float80 : Builtin_FloatType<"Float80", "f80"> {
+def Builtin_Float80 : Builtin_CachedFloatType<"Float80", "f80"> {
let summary = "80-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float128Type
-def Builtin_Float128 : Builtin_FloatType<"Float128", "f128"> {
+def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
let summary = "128-bit floating-point type";
}
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index b9f8c1ed19470d..6f52195c1d7c92 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -330,31 +330,31 @@ def F80 : F<80>;
def F128 : F<128>;
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
- BuildableType<"$_builder.getBF16Type()">;
+ BuildableType<"$_builder.getType<BFloat16Type>()">;
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
- BuildableType<"$_builder.getTF32Type()">;
+ BuildableType<"$_builder.getType<FloatTF32Type>()">;
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
- BuildableType<"$_builder.getFloat8E4M3FNType()">;
+ BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
- BuildableType<"$_builder.getFloat8E5M2Type()">;
+ BuildableType<"$_builder.getType<Float8E5M2Type>()">;
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
- BuildableType<"$_builder.getFloat8E4M3Type()">;
+ BuildableType<"$_builder.getType<Float8E4M3Type>()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
- BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
+ BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
- BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
+ BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
- BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
+ BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
- BuildableType<"$_builder.getFloat8E3M4Type()">;
+ BuildableType<"$_builder.getType<Float8E3M4Type>()">;
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
- BuildableType<"$_builder.getFloat4E2M1FNType()">;
+ BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
- BuildableType<"$_builder.getFloat6E2M3FNType()">;
+ BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
- BuildableType<"$_builder.getFloat6E3M2FNType()">;
+ BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
- BuildableType<"$_builder.getFloat8E8M0FNUType()">;
+ BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
"complex-type", "::mlir::ComplexType">;
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index c614eb39b364be..21bb0ec3d0d515 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -309,58 +309,58 @@ Type Parser::parseNonFunctionType() {
// float-type
case Token::kw_f4E2M1FN:
consumeToken(Token::kw_f4E2M1FN);
- return builder.getFloat4E2M1FNType();
+ return builder.getType<Float4E2M1FNType>();
case Token::kw_f6E2M3FN:
consumeToken(Token::kw_f6E2M3FN);
- return builder.getFloat6E2M3FNType();
+ return builder.getType<Float6E2M3FNType>();
case Token::kw_f6E3M2FN:
consumeToken(Token::kw_f6E3M2FN);
- return builder.getFloat6E3M2FNType();
+ return builder.getType<Float6E3M2FNType>();
case Token::kw_f8E5M2:
consumeToken(Token::kw_f8E5M2);
- return builder.getFloat8E5M2Type();
+ return builder.getType<Float8E5M2Type>();
case Token::kw_f8E4M3:
consumeToken(Token::kw_f8E4M3);
- return builder.getFloat8E4M3Type();
+ return builder.getType<Float8E4M3Type>();
case Token::kw_f8E4M3FN:
consumeToken(Token::kw_f8E4M3FN);
- return builder.getFloat8E4M3FNType();
+ return builder.getType<Float8E4M3FNType>();
case Token::kw_f8E5M2FNUZ:
consumeToken(Token::kw_f8E5M2FNUZ);
- return builder.getFloat8E5M2FNUZType();
+ return builder.getType<Float8E5M2FNUZType>();
case Token::kw_f8E4M3FNUZ:
consumeToken(Token::kw_f8E4M3FNUZ);
- return builder.getFloat8E4M3FNUZType();
+ return builder.getType<Float8E4M3FNUZType>();
case Token::kw_f8E4M3B11FNUZ:
consumeToken(Token::kw_f8E4M3B11FNUZ);
- return builder.getFloat8E4M3B11FNUZType();
+ return builder.getType<Float8E4M3B11FNUZType>();
case Token::kw_f8E3M4:
consumeToken(Token::kw_f8E3M4);
- return builder.getFloat8E3M4Type();
+ return builder.getType<Float8E3M4Type>();
case Token::kw_f8E8M0FNU:
consumeToken(Token::kw_f8E8M0FNU);
- return builder.getFloat8E8M0FNUType();
+ return builder.getType<Float8E8M0FNUType>();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
- return builder.getBF16Type();
+ return builder.getType<BFloat16Type>();
case Token::kw_f16:
consumeToken(Token::kw_f16);
- return builder.getF16Type();
+ return builder.getType<Float16Type>();
case Token::kw_tf32:
consumeToken(Token::kw_tf32);
- return builder.getTF32Type();
+ return builder.getType<FloatTF32Type>();
case Token::kw_f32:
consumeToken(Token::kw_f32);
- return builder.getF32Type();
+ return builder.getType<Float32Type>();
case Token::kw_f64:
consumeToken(Token::kw_f64);
- return builder.getF64Type();
+ return builder.getType<Float64Type>();
case Token::kw_f80:
consumeToken(Token::kw_f80);
- return builder.getF80Type();
+ return builder.getType<Float80Type>();
case Token::kw_f128:
consumeToken(Token::kw_f128);
- return builder.getF128Type();
+ return builder.getType<Float128Type>();
// index-type
case Token::kw_index:
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 0fa7d321844113..39c9005e449e38 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -361,22 +361,22 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
- .Case("f4E2M1FN", b.getFloat4E2M1FNType())
- .Case("f6E2M3FN", b.getFloat6E2M3FNType())
- .Case("f6E3M2FN", b.getFloat6E3M2FNType())
- .Case("f8E5M2", b.getFloat8E5M2Type())
- .Case("f8E4M3", b.getFloat8E4M3Type())
- .Case("f8E4M3FN", b.getFloat8E4M3FNType())
- .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
- .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
- .Case("f8E3M4", b.getFloat8E3M4Type())
- .Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
- .Case("bf16", b.getBF16Type())
- .Case("f16", b.getF16Type())
- .Case("f32", b.getF32Type())
- .Case("f64", b.getF64Type())
- .Case("f80", b.getF80Type())
- .Case("f128", b.getF128Type())
+ .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
+ .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
+ .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
+ .Case("f8E5M2", b.getType<Float8E5M2Type>())
+ .Case("f8E4M3", b.getType<Float8E4M3Type>())
+ .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
+ .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
+ .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
+ .Case("f8E3M4", b.getType<Float8E3M4Type>())
+ .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
+ .Case("bf16", b.getType<BFloat16Type>())
+ .Case("f16", b.getType<Float16Type>())
+ .Case("f32", b.getType<Float32Type>())
+ .Case("f64", b.getType<Float64Type>())
+ .Case("f80", b.getType<Float80Type>())
+ .Case("f128", b.getType<Float128Type>())
.Default(std::nullopt);
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8439b063f2634b..d57a7ca07ede58 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,44 +34,6 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//
-FloatType Builder::getFloat4E2M1FNType() {
- return Float4E2M1FNType::get(context);
-}
-
-FloatType Builder::getFloat6E2M3FNType() {
- return Float6E2M3FNType::get(context);
-}
-
-FloatType Builder::getFloat6E3M2FNType() {
- return Float6E3M2FNType::get(context);
-}
-
-FloatType Builder::getFloat8E5M2Type() { return Float8E5M2Type::get(context); }
-
-FloatType Builder::getFloat8E4M3Type() { return Float8E4M3Type::get(context); }
-
-FloatType Builder::getFloat8E4M3FNType() {
- return Float8E4M3FNType::get(context);
-}
-
-FloatType Builder::getFloat8E5M2FNUZType() {
- return Float8E5M2FNUZType::get(context);
-}
-
-FloatType Builder::getFloat8E4M3FNUZType() {
- return Float8E4M3FNUZType::get(context);
-}
-
-FloatType Builder::getFloat8E4M3B11FNUZType() {
- return Float8E4M3B11FNUZType::get(context);
-}
-
-FloatType Builder::getFloat8E3M4Type() { return Float8E3M4Type::get(context); }
-
-FloatType Builder::getFloat8E8M0FNUType() {
- return Float8E8M0FNUType::get(context);
-}
-
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
FloatType Builder::getF16Type() { return Float16Type::get(context); }
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index b9e745fdf4a13e..11f3446689c81c 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -1044,39 +1044,6 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
-Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
- return context->getImpl().f4E2M1FNTy;
-}
-Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
- return context->getImpl().f6E2M3FNTy;
-}
-Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
- return context->getImpl().f6E3M2FNTy;
-}
-Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
- return context->getImpl().f8E5M2Ty;
-}
-Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) {
- return context->getImpl().f8E4M3Ty;
-}
-Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
- return context->getImpl().f8E4M3FNTy;
-}
-Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
- return context->getImpl().f8E5M2FNUZTy;
-}
-Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
- return context->getImpl().f8E4M3FNUZTy;
-}
-Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
- return context->getImpl().f8E4M3B11FNUZTy;
-}
-Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
- return context->getImpl().f8E3M4Ty;
-}
-Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
- return context->getImpl().f8E8M0FNUTy;
-}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}
|
@@ -1044,39 +1044,6 @@ AbstractType::lookup(StringRef name, MLIRContext *context) { | |||
/// This should not be used directly. | |||
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } | |||
|
|||
Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We also need to remove the fields from the ContextImpl (and their initialization)
8c85f1f
to
2e1833f
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/35/builds/6360 Here is the relevant piece of the build log for the reference
|
Pulls in llvm/llvm-project#123200 which is useful and also handles #5664. Integrations were required due to llvm/llvm-project#123026, llvm/llvm-project#123321 and llvm/llvm-project#123326. Also closes #5685
Pulls in llvm/llvm-project#123200 which is useful and also handles triton-lang#5664. Integrations were required due to llvm/llvm-project#123026, llvm/llvm-project#123321 and llvm/llvm-project#123326. Also closes triton-lang#5685
Remove builder API (e.g.,
b.getFloat4E2M1FNType()
) and caching inMLIRContext
for low-precision FP types. Types are still cached in the type uniquer.For details, see: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28
Note for LLVM integration: Use
b.getType<Float4E2M1FNType>()
orFloat4E2M1FNType::get(b.getContext())
instead ofb.getFloat4E2M1FNType()
.