Skip to content

Commit 67a1fdb

Browse files
[mlir][IR] Treat tf32 as 19-bit float (#116738)
TF32 is a variant of F32 that is truncated to 19 bits. There used to be special handling in `FloatType::getWidth()` so that TF32 was treated as a 32-bit float in some places. (Some places use `FloatType::getWidth`, others directly query the `APFloat` semantics.) This caused problems because `FloatType::getWidth` did not agree with the underlying `APFloat` semantics. In particular, creating an elements attr / array attr with `tf32` element type crashed. E.g.: ``` "foo"() {attr = dense<4.0> : tensor<tf32>} : () -> () mlir-opt: llvm-project/llvm/lib/Support/APFloat.cpp:4108: void llvm::detail::IEEEFloat::initFromAPInt(const fltSemantics *, const APInt &): Assertion `api.getBitWidth() == Sem->sizeInBits' failed. PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace. ``` ``` "foo"() {f32attr = array<tf32: 1024.>} : () -> () mlir-opt: llvm-project/mlir/lib/AsmParser/AttributeParser.cpp:847: void (anonymous namespace)::DenseArrayElementParser::append(const APInt &): Assertion `data.getBitWidth() % 8 == 0' failed. PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace. ``` It is unclear why the special handling for TF32 is needed. For reference: #107372
1 parent 3a5cf6d commit 67a1fdb

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,6 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
9191
//===----------------------------------------------------------------------===//
9292

9393
unsigned FloatType::getWidth() {
94-
// The actual width of TF32 is 19 bits. However, since it is a truncated
95-
// version of Float32, we treat it as 32 bits in MLIR FloatType::getWidth
96-
// for compatibility.
97-
if (llvm::isa<FloatTF32Type>(*this))
98-
return 32;
9994
return APFloat::semanticsSizeInBits(getFloatSemantics());
10095
}
10196

mlir/test/IR/attribute.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,14 @@ func.func @correct_type_pass() {
561561

562562
// -----
563563

564+
func.func @tf32_elements_attr() {
565+
// CHECK: "foo"() {attr = dense<4.000000e+00> : tensor<tf32>} : () -> ()
566+
"foo"() {attr = dense<4.0> : tensor<tf32>} : () -> ()
567+
return
568+
}
569+
570+
// -----
571+
564572
//===----------------------------------------------------------------------===//
565573
// Test StringElementsAttr
566574
//===----------------------------------------------------------------------===//
@@ -675,6 +683,14 @@ func.func @dense_array_attr() attributes {
675683

676684
// -----
677685

686+
func.func @test_invalid_bitwidth_type() {
687+
// expected-error @below{{element type bitwidth must be a multiple of 8}}
688+
"foo"() {tf32attr = array<tf32: 1024.0>} : () -> ()
689+
return
690+
}
691+
692+
// -----
693+
678694
func.func @testConfinedDenseArrayAttr() {
679695
"test.confined_dense_array_attr"() {
680696
i64attr = array<i64: 0, 2, 3>,

0 commit comments

Comments
 (0)