Skip to content

[mlir][IR] Treat tf32 as 19-bit float #116738

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 20, 2024

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 19, 2024

TF32 is a variant of F32 that is truncated to 19 bits. There used to be special handling in FloatType::getWidth() so that TF32 was treated as a 32-bit float in some places. (Some places use FloatType::getWidth, others directly query the APFloat semantics.) This caused problems because FloatType::getWidth did not agree with the underlying APFloat semantics.

In particular, creating an elements attr / array attr with tf32 element type crashed. E.g.:

"foo"() {attr = dense<4.0> : tensor<tf32>} : () -> ()

mlir-opt: llvm-project/llvm/lib/Support/APFloat.cpp:4108: void llvm::detail::IEEEFloat::initFromAPInt(const fltSemantics *, const APInt &): Assertion `api.getBitWidth() == Sem->sizeInBits' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
"foo"() {f32attr = array<tf32: 1024.>} : () -> ()

mlir-opt: llvm-project/mlir/lib/AsmParser/AttributeParser.cpp:847: void (anonymous namespace)::DenseArrayElementParser::append(const APInt &): Assertion `data.getBitWidth() % 8 == 0' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.

It is unclear why the special handling for TF32 is needed. For reference: #107372

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

TF32 is a variant of F32 that is truncated to 19 bits. There used to be special handling in FloatType::getWidth() such that TF32 was treated as a 32-bit float in some places. (Some places use FloatType::getWidth, others directly query the APFloat semantics.) This caused problems because FloatType::getWidth did not agree with the underlying APFloat semantics.

In particular, creating an elements attr / array attr with tf32 element type crashed. E.g.:

"foo"() {attr = dense&lt;4.0&gt; : tensor&lt;tf32&gt;} : () -&gt; ()

mlir-opt: llvm-project/llvm/lib/Support/APFloat.cpp:4108: void llvm::detail::IEEEFloat::initFromAPInt(const fltSemantics *, const APInt &amp;): Assertion `api.getBitWidth() == Sem-&gt;sizeInBits' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
"foo"() {f32attr = array&lt;tf32: 1024.&gt;} : () -&gt; ()

mlir-opt: llvm-project/mlir/lib/AsmParser/AttributeParser.cpp:847: void (anonymous namespace)::DenseArrayElementParser::append(const APInt &amp;): Assertion `data.getBitWidth() % 8 == 0' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.

It is unclear why the special handling for TF32 is needed. For reference: #105573


Full diff: https://github.com/llvm/llvm-project/pull/116738.diff

2 Files Affected:

  • (modified) mlir/lib/IR/BuiltinTypes.cpp (-5)
  • (modified) mlir/test/IR/attribute.mlir (+16)
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 25e9f80c9963cb..e8e8f3cdfbfd73 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -91,11 +91,6 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  // The actual width of TF32 is 19 bits. However, since it is a truncated
-  // version of Float32, we treat it as 32 bits in MLIR FloatType::getWidth
-  // for compatibility.
-  if (llvm::isa<FloatTF32Type>(*this))
-    return 32;
   return APFloat::semanticsSizeInBits(getFloatSemantics());
 }
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..0085d64ae82b6b 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -561,6 +561,14 @@ func.func @correct_type_pass() {
 
 // -----
 
+func.func @tf32_elements_attr() {
+  // CHECK: "foo"() {attr = dense<4.000000e+00> : tensor<tf32>} : () -> ()
+  "foo"() {attr = dense<4.0> : tensor<tf32>} : () -> ()
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Test StringElementsAttr
 //===----------------------------------------------------------------------===//
@@ -675,6 +683,14 @@ func.func @dense_array_attr() attributes {
 
 // -----
 
+func.func @test_invalid_bitwidth_type() {
+  // expected-error @below{{element type bitwidth must be a multiple of 8}}
+  "foo"() {tf32attr = array<tf32: 1024.0>} : () -> ()
+  return
+}
+
+// -----
+
 func.func @testConfinedDenseArrayAttr() {
   "test.confined_dense_array_attr"() {
     i64attr = array<i64: 0, 2, 3>,

@River707
Copy link
Contributor

@jpienaar Do you remember the reason why this is special cased? Is there somewhere in the TF codebase that is depending on this behavior? I'd love to remove all of the special cases just rely on APFloat for all of the semantics here.

@sergey-kozub
Copy link
Contributor

You added a test with an expected failure for array. Does this change mean users cannot create arrays of TF32 values anymore, and can use only TF32 scalars?

@matthias-springer
Copy link
Member Author

You added a test with an expected failure for array. Does this change mean users cannot create arrays of TF32 values anymore, and can use only TF32 scalars?

It also didn’t work before. We used to crash, now there is a proper error.

Copy link
Contributor

@sergey-kozub sergey-kozub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also noticed this strange piece of code before, it doesn't make sense to me, so if no tests are failing then LGTM.

@matthias-springer matthias-springer merged commit 67a1fdb into main Nov 20, 2024
11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/tf32_bitwidth branch November 20, 2024 08:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants