
Add arith expansion of f8E8M0 type for extf/trunc ops #140332


Open (wants to merge 6 commits into main)

Conversation

@umangyadav (Contributor) commented May 17, 2025

The F8E8M0 floating-point type represents the biased exponent bits of the F32 type in the OCP Microscaling (MX) floating-point formats.

https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR expands arith.truncf and arith.extf to support this type.

For arith.truncf, the thing to note is that the F8E8M0FNU type has a single NaN representation, encoded as 0xFF. Therefore all kinds of NaNs and +/-Inf in Float32Type map to NaN in F8E8M0FNU. F8E8M0FNU also has no sign bit, so this is a lossy, irreversible downcast.
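As a rough sketch, the scalar f32 case expands to a shift-and-truncate of the exponent bits (the exact patterns and CHECK lines are in the diff below):

func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
  %bits = arith.bitcast %arg0 : f32 to i32   // reinterpret the f32 bits
  %c23  = arith.constant 23 : i32            // width of the f32 mantissa
  %exp  = arith.shrui %bits, %c23 : i32      // shift the biased exponent down
  %exp8 = arith.trunci %exp : i32 to i8      // keep the 8 exponent bits; the sign bit is dropped
  %res  = arith.bitcast %exp8 : i8 to f8E8M0FNU
  return %res : f8E8M0FNU
}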

cc: @krzysz00 @MaheshRavishankar @Muzammiluddin-Syed-ECE

@llvmbot added the mlir:core (MLIR Core Infrastructure), mlir, and mlir:arith labels on May 17, 2025
@llvmbot (Member) commented May 17, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Umang Yadav (umangyadav)

Changes

The F8E8M0 floating-point type represents the biased exponent bits of the F32 type in the OCP Microscaling (MX) floating-point formats.

https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR expands arith.truncf and arith.extf to support this type.

For arith.truncf, the thing to note is that the F8E8M0FNU type has a single NaN representation, encoded as 0xFF. Therefore all kinds of NaNs and +/-Inf in Float32Type map to NaN in F8E8M0FNU. F8E8M0FNU also has no sign bit, so this is a lossy, irreversible downcast.


Full diff: https://github.com/llvm/llvm-project/pull/140332.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (+3)
  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+5-3)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+129-9)
  • (modified) mlir/lib/IR/Types.cpp (+1-1)
  • (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+129-1)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 8d81d8ec14ee7..5aaac8d8e3dc5 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
 void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
+void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d026d494cb50c..e14b2aeee1c69 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -14,9 +14,11 @@ include "mlir/Pass/PassBase.td"
 def ArithExpandOpsPass : Pass<"arith-expand"> {
   let summary = "Legalize Arith ops to be convertible to LLVM.";
   let dependentDialects = ["vector::VectorDialect"];
-  let options = [
-    Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
-           "Enable the BF16 expansion patterns">,
+  let options =
+      [Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+              "Enable the BF16 expansion patterns">,
+       Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+              "Enable the F8E8M0 expansion patterns">,
   ];
 }
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 4ffdbfa5b1224..55a7c6bb11784 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -109,6 +109,7 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
+  bool isF8E8M0FNU() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2d627e523cde5..f5240cf92bdc4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Constant used to make the rounding bias.
     Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
     // Constant used to generate a quiet NaN.
-    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
     // Small constants used to address bits.
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
     Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Now that the rounding-bias has been added, truncating the low bits
     // yields the correctly rounded result.
     Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
-    Value normalCaseResult_i16 =
+    Value normalCaseResultI16 =
         b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
     // Select either the above-computed result, or a quiet NaN constant
     // if the input was NaN.
     Value select =
-        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
     Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isF8E8M0FNU()) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+    }
+
+    if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
+      return rewriter.notifyMatchFailure(
+          op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
+
+    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+    // create constants for NaNs
+    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // select for NaNs
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.isBF16()) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.isF16()) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/*
+TruncF to F8E8M0 is expected to extract the exponent bits of the F32 type.
+Since all kinds of Infs and NaNs map to the same exponent bits in F32,
+they all map to NaN in the F8E8M0 type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(resultTy);
+    if (!resultETy.isF8E8M0FNU()) {
+      return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+    }
+    if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
+      return rewriter.notifyMatchFailure(
+          op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
+    }
+
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
+    if (!operandETy.isF32()) {
+      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+    }
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
       arith::MinNumFOp
     >();
 
-    if (includeBf16) {
+    if(includeBf16) {
       arith::populateExpandBFloat16Patterns(patterns);
+    }
+    if(includeF8E8M0) {
+      arith::populateExpandF8E8M0Patterns(patterns);
+    }
+    if (includeBf16 || includeF8E8M0) {
       target.addDynamicallyLegalOp<arith::ExtFOp>(
-        [](arith::ExtFOp op) {
+        [=](arith::ExtFOp op) {
           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
           Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isBF16() && outETy.isF32());
+          if(includeBf16 && includeF8E8M0)
+            return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
+          if(includeBf16)
+            return !(inETy.isBF16() && outETy.isF32());
+          return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
         });
 
       target.addDynamicallyLegalOp<arith::TruncFOp>(
-        [](arith::TruncFOp op)  {
+        [=](arith::TruncFOp op)  {
           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
           Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isF32() && outETy.isBF16());
+          if(includeBf16 && includeF8E8M0) 
+            return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16())); 
+          if(includeBf16)
+            return !(inETy.isF32() && outETy.isBF16());
+          return 
+            !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16())); 
         });
     }
-
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -389,6 +504,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 765b787d3d17a..975b26ae4369f 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,7 +33,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
 //===----------------------------------------------------------------------===//
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-
+bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index bdf022642b717..5b6badf13d763 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
 
 // Test ceil divide with signed integer
 // CHECK-LABEL:       func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
 // CHECK-LABEL: @truncf_vector_f32
 // CHECK-NOT: arith.truncf
 
+// -----
+func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
+    %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
+    return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
+    %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
+    return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
+// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+
+// -----
+func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
+    %0 = arith.extf %arg0 : f8E8M0FNU to f32
+    return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
+    %0 = arith.extf %arg0 : f8E8M0FNU to f16
+    return %0 : f16
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
+// CHECK: return %[[F16_RESULT]]
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
+    return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
+    return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
+// CHECK-NOT: arith.extf
+
+
 // -----
 
 func.func @maxsi(%a: i32, %b: i32) -> i32 {

@krzysz00 (Contributor) left a comment:

A few comments, but this does seem like a good fit for ExpandOps

@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
Contributor:

Unrelated change?

Contributor (Author):

Yes, I noticed the invalid case style for the variable name, so I changed it.

LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto operand = op.getOperand();
Contributor:

Can probably be Value

Contributor (Author):

Done

return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
}

if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
Contributor:

I'd not hardcode a list of targets here - I'd just plow forward with a cast to f32 and if the final target has a bitwidth less than 32 you truncate and for > 32 you extend

Contributor (Author):

Done.
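
For reference, a minimal sketch of the via-f32 path for a narrower target element type (this mirrors the f8E8M0FNU -> f16 expansion tested in the diff above; it is an illustration, not additional code being proposed here):

func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
  %bits    = arith.bitcast %arg0 : f8E8M0FNU to i8
  %cNaN8   = arith.constant -1 : i8            // 0xFF, the single f8E8M0FNU NaN
  %cNaN32  = arith.constant -1 : i32           // an f32 NaN bit pattern
  %c23     = arith.constant 23 : i32
  %wide    = arith.extui %bits : i8 to i32
  %f32bits = arith.shli %wide, %c23 : i32      // place the exponent above the 23 mantissa bits
  %isNaN   = arith.cmpi eq, %bits, %cNaN8 : i8
  %sel     = arith.select %isNaN, %cNaN32, %f32bits : i32
  %f32     = arith.bitcast %sel : i32 to f32
  %res     = arith.truncf %f32 : f32 to f16    // narrow to the requested element type
  return %res : f16
}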

if (!resultETy.isF8E8M0FNU()) {
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
}
if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
Contributor:

Same note: extend or truncate to f32 as needed

Contributor (Author):

Done.

@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
arith::MinNumFOp
>();

if (includeBf16) {
if(includeBf16) {
Contributor:

Put the space back?

Contributor (Author):

Done.

Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
return !(inETy.isF32() && outETy.isBF16());
if(includeBf16 && includeF8E8M0)
Contributor:

I suspect the condition can be simplified here

Contributor (Author):

Simplified.

@@ -109,6 +109,7 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
bool isF8E8M0FNU() const;
Contributor:

all the isF* for small types were recently removed: #123326

Contributor (Author):

Went back to using llvm::isa. Thanks

@tgymnich (Member) commented May 17, 2025

When converting f32 to f8e8m0, should we map negative numbers to 0 (e.g. underflow to smallest normalized value)?

@krzysz00 (Contributor):

As far as I'm aware, the cast down is meant to drop the sign bit, not go to 0.

Especially since f8E8M0FNU doesn't have a 0

@tgymnich (Member) commented May 17, 2025

Just dropping the sign makes sense to me.
I meant 0 as in smallest normalized value.

https://github.com/iree-org/iree/blob/c447638dae70fc21f5d84ad4cf402ca034a60cda/runtime/src/iree/base/internal/math.h#L596

@krzysz00 I assume this is wrong then and needs to be changed?

@krzysz00 (Contributor):

Yeah, given that this is meant to be an exponent used to scale other floats, I figure fabs() might be the better thing to have there.

@bjacob (Contributor) commented May 19, 2025

As discussed on the IREE issue, I believe there is a long-standing convention that if a conversion returns a finite value, it should be the nearest representable value. Table 3 of the OCP spec, which discusses the "overflow or saturate" semantics, corroborates this, and taking the absolute value breaks it.

Concretely here, if the input value is the result of a calculation that accidentally yielded a tiny negative value, clamping to the nearest representable value seems more useful than bouncing back towards larger magnitudes.

@krzysz00 (Contributor):

Yeah, having slept on it, sending negative numbers to 0 / 2^-127 makes sense, mostly because it's very weird for truncf to mess with the ordering between floats

@umangyadav (Contributor, Author):

> As discussed on the IREE issue, I believe there is a long-standing convention that if a conversion returns a finite value, it should be the nearest representable value. Table 3 of the OCP spec, which discusses the "overflow or saturate" semantics, corroborates this, and taking the absolute value breaks it.

Table 3 is for FP8 types except F8E8M0.

F8E8M0 is used for the shared block scale, which is in fact calculated by extracting the exponent bits of fabs(value : f32). So it is not really a "conversion" or "cast" in the conventional sense.

The OCP spec has this definition for F8E8M0:
"E8M0 is an unsigned representation of a conventional biased Float32 exponent"

Here is one reference:
https://github.com/amd/Quark/blob/60cd6e46d20a5553a7b1a754c0459737f3c31fde/quark/onnx/operators/custom_ops/src/mx/cuda/mx_kernel.cu#L63
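
Spelled out, that definition means the stored byte E decodes as a pure power of two (an editorial restatement of the spec wording above; the bias of 127 comes from the F32 exponent format, and the 0xFF NaN encoding is as described in this PR):

$$\mathrm{value}(E) = 2^{\,E - 127}, \quad E \in \{0, \dots, 254\}, \qquad E = 255\ (\mathtt{0xFF}) \text{ encodes NaN}$$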

@bjacob (Contributor) commented May 20, 2025

Right, so the specification is under-defined when it comes to E8M0 conversions. I'm just suggesting what I believe is the more important invariant to preserve: "if a conversion produces a finite value then it should be a nearest value". Also, what Krzysztof said about type conversions being monotonic functions is another important invariant to preserve. By contrast, dropping the sign bit seems arbitrary, except maybe if one thinks less in terms of the real numbers being encoded and more in terms of how they are encoded with or without a sign bit.

@krzysz00 (Contributor):

As a side note, llvm::APFloat crashes on negative f8E8M0s

Re arith.truncf, it's worth noting that most operations that operate on an "f8E8M0-in-f32" (AMD has a bunch of these) ignore the sign bit (that is, they implicitly take a fabs()).

From the sources I'm seeing, the "natural" conversion function from f32 to f8E8M0 isn't really a floating-point truncate, but rather:

f8E8M0 @get_exponent(float %source) {
  %source.bits = bitcast float %source to i32
  %no.mantissa = lshr i32 %source.bits, 23
  %f8e8m0 = trunc i32 %no.mantissa to i8
  ; uno is true iff the input is NaN, so NaN inputs select the 0xff encoding
  %was.nan = fcmp uno float %source, %source
  %result = select i1 %was.nan, i8 0xff, i8 %f8e8m0
  ret f8E8M0 %result
}

I suspect that we may want an operation that isn't named truncf here - something like arith.extract_exponent, but then that leaves us with truncf unimplemented for f8E8M0 ... which ... maybe that's true.

(That is to say, in practice, when we're writing out a buffer of scales after a post-matmul quantization, we'll want to do the exponent extraction, not a more principled truncf)

@umangyadav (Contributor, Author) commented May 21, 2025

I am open to renaming arith.truncf to arith.extract_exponent, but then I am not sure what I should call arith.extf : f8E8M0 to f32, because arith.extf of f8E8M0 to f32 is a regular upcast in the conventional sense. So I would have arith.extf but no corresponding arith.truncf from F32 to F8E8M0, only arith.extract_exponent.

Here is the Python code from the official spec implementation:

https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L49

It is doing max(fabs(x)) along with clamping later. (I argue that clamping is not necessary for the f32 to f8E8M0 downcast, since both have 8-bit exponents.)

One thing I missed is that it maps an f32 zero to 2^-126, not 2^-127, in f8E8M0. That allows checking for flushed denorms later in the code here. I'll push a fix for that.

Zero would be an exception to arith.extract_exponent semantics if we were to name it that. One would expect the exponent of zero to be zero, but it won't be if the spec is followed; it would be 0x01.

@bjacob (Contributor) commented May 21, 2025

Just want to make clear that I was only commenting from a perspective of trying to treat E8M0 as a real FP type.

Maybe that was pointless, since that type is inherently only ever useful as an MX scale type. So it may make more sense to treat it as the specialized type that it is, which means: do whatever you feel is the established practice in this field.

I personally think that E8M0 is a mistake even within its intended application domain of MX scales, because in that domain, since scales are shared across many matrix elements, their bit width doesn't matter that much, and accordingly it would have been better to simply have, say, bf16 scales. I don't say this to throw another wrench into this discussion; on the contrary, in my mental model E8M0 is probably a short-lived oddity, so we may as well not think too hard about it.

@umangyadav (Contributor, Author) commented May 21, 2025

> One thing I missed is that it maps an f32 zero to 2^-126, not 2^-127, in f8E8M0. That allows checking for flushed denorms later in the code here. I'll push a fix for that.

Thinking more about it, I don't think I need to make any change for this.

shared_exponent is calculated as floor(log2(max(fabs(x)))) where max is taken along some block axis.

https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L76

The only way shared_exponent can become zero is if the block values are all zero or subnormal floats.

Quantization is carried out as A = A / (2 ^ shared_exponent). Therefore, even if I were to map zero to 0x01, it would still result in zero as the quantized value.
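
A tiny worked example of those two formulas (an editorial illustration, not from the spec): for a block A = {0.75, -3.0, 6.5},

$$\mathrm{shared\_exponent} = \lfloor \log_2 \max_i |x_i| \rfloor = \lfloor \log_2 6.5 \rfloor = 2, \qquad A / 2^{2} = \{0.1875,\ -0.75,\ 1.625\}$$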

Note that the spec algorithm flushes denorms on line 4 (see the algorithm below). Therefore a block with all denorms and zeros will result in a quantized value of zero.

https://arxiv.org/pdf/2310.10537

[Figure: quantization algorithm from the linked paper, with denorm flushing on line 4]

Therefore I think it is better to leave the zero mapping as 0x00 and not complicate it.

CC: @krzysz00 @bjacob @tgymnich @dhernandez0 Let me know if this makes sense. I've already addressed review comments so far.

@krzysz00 (Contributor) left a comment:

Minor nitpick, otherwise seems fine to land for now (and we'll want to reopen the IREE PR that corresponds to the abs() method for quantizing to a scale).

Having grep'd for patterns involving TruncFOp, I couldn't find any canonicalizations or foldings that'd get messed up by this change

Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
if(includeBf16)
Contributor:

Nit re space between if and the condition - surprised clang-format doesn't catch this

Contributor (Author):

Done. Thanks

Comment on lines 400 to 407
Type i8Ty = b.getI8Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
i8Ty = shapedTy.clone(i8Ty);
i32Ty = shapedTy.clone(i32Ty);
f32Ty = shapedTy.clone(f32Ty);
}
Member:

I would expect something like

auto cloneIfShaped = [&](Type baseTy) -> Type {
  if (auto shapedTy = dyn_cast<ShapedType>(operandTy))
    return shapedTy.clone(baseTy);
  return baseTy;
};

Type i8Ty = cloneIfShaped(b.getI8Type());
Type i32Ty = cloneIfShaped(b.getI32Type());
Type f32Ty = cloneIfShaped(b.getF32Type());

Also, since you are using it in the above rewrite, this lambda can be made a static function.

Contributor (Author):

Good idea. Made a static method to use in all places.

Labels: mlir:arith, mlir:core (MLIR Core Infrastructure), mlir
7 participants