-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Add arith expansion of f8E8M0 type for extf/trunc ops #140332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-arith Author: Umang Yadav (umangyadav) ChangesF8E8M0 floating type is supposed to represent biased exponent bits of F32 type in OCP Micro scaling floating point formats. https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf This PR expands arith.truncf and arith.extf to support this behavior. For the Full diff: https://github.com/llvm/llvm-project/pull/140332.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 8d81d8ec14ee7..5aaac8d8e3dc5 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
+void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d026d494cb50c..e14b2aeee1c69 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -14,9 +14,11 @@ include "mlir/Pass/PassBase.td"
def ArithExpandOpsPass : Pass<"arith-expand"> {
let summary = "Legalize Arith ops to be convertible to LLVM.";
let dependentDialects = ["vector::VectorDialect"];
- let options = [
- Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
- "Enable the BF16 expansion patterns">,
+ let options =
+ [Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+ "Enable the BF16 expansion patterns">,
+ Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+ "Enable the F8E8M0 expansion patterns">,
];
}
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 4ffdbfa5b1224..55a7c6bb11784 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -109,6 +109,7 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
+ bool isF8E8M0FNU() const;
bool isBF16() const;
bool isF16() const;
bool isTF32() const;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2d627e523cde5..f5240cf92bdc4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
- Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+ Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
// Small constants used to address bits.
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
- Value normalCaseResult_i16 =
+ Value normalCaseResultI16 =
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
- b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+ b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
Value result = b.create<arith::BitcastOp>(resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
};
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!operandETy.isF8E8M0FNU()) {
+ return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+ }
+
+ if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
+ return rewriter.notifyMatchFailure(
+ op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
+ }
+
+ Type i8Ty = b.getI8Type();
+ Type i32Ty = b.getI32Type();
+ Type f32Ty = b.getF32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+ i8Ty = shapedTy.clone(i8Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ f32Ty = shapedTy.clone(f32Ty);
+ }
+
+ Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+ // create constants for NaNs
+ Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+ Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+ Value isNan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ // select for NaNs
+ f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (resultETy.isBF16()) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.isF16()) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/*
+TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
+Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
+they all map to NaN in F8E8M0 Type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultTy = op.getType();
+ Type resultETy = getElementTypeOrSelf(resultTy);
+ if (!resultETy.isF8E8M0FNU()) {
+ return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+ }
+ if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
+ return rewriter.notifyMatchFailure(
+ op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
+ }
+
+ if (op.getRoundingmodeAttr()) {
+ return rewriter.notifyMatchFailure(
+ op, "only applicable to default rounding mode.");
+ }
+
+ Type i8Ty = b.getI8Type();
+ Type i32Ty = b.getI32Type();
+ Type f32Ty = b.getF32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+ i8Ty = shapedTy.clone(i8Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ f32Ty = shapedTy.clone(f32Ty);
+ }
+ if (!operandETy.isF32()) {
+ operand = b.create<arith::ExtFOp>(f32Ty, operand);
+ }
+ Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+ Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+ Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+ Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
arith::MinNumFOp
>();
- if (includeBf16) {
+ if(includeBf16) {
arith::populateExpandBFloat16Patterns(patterns);
+ }
+ if(includeF8E8M0) {
+ arith::populateExpandF8E8M0Patterns(patterns);
+ }
+ if (includeBf16 || includeF8E8M0) {
target.addDynamicallyLegalOp<arith::ExtFOp>(
- [](arith::ExtFOp op) {
+ [=](arith::ExtFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
- return !(inETy.isBF16() && outETy.isF32());
+ if(includeBf16 && includeF8E8M0)
+ return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
+ if(includeBf16)
+ return !(inETy.isBF16() && outETy.isF32());
+ return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
});
target.addDynamicallyLegalOp<arith::TruncFOp>(
- [](arith::TruncFOp op) {
+ [=](arith::TruncFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
- return !(inETy.isF32() && outETy.isBF16());
+ if(includeBf16 && includeF8E8M0)
+ return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
+ if(includeBf16)
+ return !(inETy.isF32() && outETy.isBF16());
+ return
+ !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
});
}
-
// clang-format on
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -389,6 +504,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}
+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+ patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+ patterns.getContext());
+}
+
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
// clang-format off
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 765b787d3d17a..975b26ae4369f 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,7 +33,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
//===----------------------------------------------------------------------===//
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-
+bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index bdf022642b717..5b6badf13d763 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
// Test ceil divide with signed integer
// CHECK-LABEL: func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
// CHECK-LABEL: @truncf_vector_f32
// CHECK-NOT: arith.truncf
+// -----
+func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
+ %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
+ return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
+ %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
+ return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
+// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
+ %0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
+ return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+
+// -----
+func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
+ %0 = arith.extf %arg0 : f8E8M0FNU to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
+ %0 = arith.extf %arg0 : f8E8M0FNU to f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
+// CHECK: return %[[F16_RESULT]]
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+ %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
+ return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
+// CHECK-NOT: arith.extf
+
+
// -----
func.func @maxsi(%a: i32, %b: i32) -> i32 {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments, but this does seem like a good fit for ExpandOps
@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { | |||
// Constant used to make the rounding bias. | |||
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter); | |||
// Constant used to generate a quiet NaN. | |||
Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter); | |||
Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I noticed invalid case style for varible name. Therefore changed it.
LogicalResult matchAndRewrite(arith::ExtFOp op, | ||
PatternRewriter &rewriter) const final { | ||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
auto operand = op.getOperand(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can probably be Value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU"); | ||
} | ||
|
||
if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd not hardcode a list of targets here - I'd just plow forward with a cast to f32 and if the final target has a bitwidth less than 32 you truncate and for > 32 you extend
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
if (!resultETy.isF8E8M0FNU()) { | ||
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU"); | ||
} | ||
if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same note: extend or truncate to f32 as needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -351,23 +453,36 @@ struct ArithExpandOpsPass | |||
arith::MinNumFOp | |||
>(); | |||
|
|||
if (includeBf16) { | |||
if(includeBf16) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put the space back?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Type inETy = getElementTypeOrSelf(op.getOperand().getType()); | ||
Type outETy = getElementTypeOrSelf(op.getType()); | ||
return !(inETy.isF32() && outETy.isBF16()); | ||
if(includeBf16 && includeF8E8M0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect the condition can be simplified here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplified.
mlir/include/mlir/IR/Types.h
Outdated
@@ -109,6 +109,7 @@ class Type { | |||
// Convenience predicates. This is only for floating point types, | |||
// derived types should use isa/dyn_cast. | |||
bool isIndex() const; | |||
bool isF8E8M0FNU() const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all the isF* for small types were recently removed: #123326
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Went back to using llvm::isa
. Thanks
When converting f32 to f8e8m0, should we map negative numbers to 0 (e.g. underflow to smallest normalized value)? |
As far as I'm aware, the cast down is meant to drop the sign, but not go to 0 — especially since f8E8M0FNU doesn't have a 0. |
Just dropping the sign makes sense to me. @krzysz00 I assume this is wrong then and needs to be changed? |
Yeah, given that this meant to be an exponent part for scaling other floats by, I figure fabs() might be a better thing to have there |
As discussed on the IREE issue, I believe that there is a longstanding tradition that if a conversion returns a finite value then it should be the nearest representable value, that is corroborated by table 3 in the OCP spec discussing the "overflow or saturate" semantics, that is broken by taking the absolute value. Concretely here, if the input value is the result of a calculation that accidentally yielded a tiny negative value, clamping to the nearest representable value seems more useful than bouncing back towards larger magnitudes. |
Yeah, having slept on it, sending negative numbers to 0 / 2^-127 makes sense, mostly because it's very weird for truncf to mess with the ordering between floats |
Table 3 is for FP8 types except F8E8M0. FP8E8M0 is used for the shared block scale, which is in fact calculated by extracting the exponent bits of the block values. The OCP Spec has this definition for Fp8E8M0. Here is one of the references: |
Right, so the specification is under-defined when it comes to E8M0 conversions. I'm just suggesting what I believe is the more important invariant to preserve: "if a conversion produces a finite value then it should be a nearest value". Also, what Krzysztof said about type conversions being monotonic functions is another important invariant to preserve. By contrast, dropping the sign bit seems arbitrary, except maybe if one thinks less in terms of the real numbers being encoded and more in terms of how they are encoded with or without a sign bit. |
As a side note, Re From the sources I'm seeing, the "natural" conversion function from f32 to f8E8M0 isn't really a floating point truncate, but f8E8M0 @get_exponent(float %source) {
%source.bits = bitcast float %source to i32
%no.mantissa = lshr i32 %source.bits, 23
%f8e8m0 = trunc i32 %no.mantissa to i8
  %was.nan = fcmp une float %source, %source
  %result = select i1 %was.nan, i8 0xff, i8 %f8e8m0
ret f8E8M0 %result
} I suspect that we may want an operation that isn't named (That is to say, in practice, when we're writing out a buffer of scales after a post-matmul quantization, we'll want to do the exponent extraction, not a more principled truncf) |
I am open to renaming Here is the python code from official spec implementation. which is doing One thing i missed is that it is mapping Zero would be exception to |
Just want to make clear that I was only commenting from a perspective of trying to treat E8M0 as a real FP type. Maybe that was pointless since that type is inherently only ever useful as a MX scale type. So it may make more sense to treat it as the specialized type that it is, which means, do whatever you feel is the established practice in this field. I personally think that E8M0 is a mistake even within its intended application domain for MX scales, because in that domain, since scales are shared across many matrix elements, their bit-width doesn't matter that much, and accordingly, it would have been better so simply have, say, bf16 scales. I don't say this to thrown another wrench in this discussion, on the contrary: this means that in my mental model, E8M0 is probably a short-lived oddity so we may as well not think too hard about it. |
Thinking more about it, I don't think I need to make any change for this.
The only possible way shared_exponent becomes zero is if the block values are either zero or subnormal float values. Quantization is carried out by doing Note that spec algorithm flushes denorms on Line 4 below. Therefore a block with all denorms and zeros will result in a quantized value of zero. https://arxiv.org/pdf/2310.10537 Therefore I think it is better to leave the zero mapping to CC: @krzysz00 @bjacob @tgymnich @dhernandez0 Let me know if this makes sense. I've already addressed review comments so far. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor nitpick, otherwise seems fine to land for now (and we'll want to reopen the IREE PR that corresponds to the abs() method for quantizing to a scale).
Having grep'd for patterns involving TruncFOp, I couldn't find any canonicalizations or foldings that'd get messed up by this change
Type inETy = getElementTypeOrSelf(op.getOperand().getType()); | ||
Type outETy = getElementTypeOrSelf(op.getType()); | ||
bool legalTypes = true; | ||
if(includeBf16) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit re space between if
and the condition - surprised clang-format doesn't catch this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thanks
Type i8Ty = b.getI8Type(); | ||
Type i32Ty = b.getI32Type(); | ||
Type f32Ty = b.getF32Type(); | ||
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) { | ||
i8Ty = shapedTy.clone(i8Ty); | ||
i32Ty = shapedTy.clone(i32Ty); | ||
f32Ty = shapedTy.clone(f32Ty); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect something like
auto cloneIfShaped = [&](Type baseTy) -> Type {
if (auto shapedTy = dyn_cast<ShapedType>(operandTy))
return shapedTy.clone(baseTy);
return baseTy;
};
Type i8Ty = cloneIfShaped(b.getI8Type());
Type i32Ty = cloneIfShaped(b.getI32Type());
Type f32Ty = cloneIfShaped(b.getF32Type());
Also, since you are using it in the above rewrite, this lambda can be made a static function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea. Made a static method to use in all places.
F8E8M0 floating type is supposed to represent biased exponent bits of F32 type in OCP Micro scaling floating point formats.
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
This PR expands
arith.truncf
andarith.extf
to support this behavior.For the
arith.truncf
thing to note here is that F8E8M0FNU type has one NaN representation which is encoded as0xFF
. Therefore alll kinds of NaNs and +/-Inf in Float32Type would map to NaN in F8E8M0FNU. F8E8M0FNU doesn't have a sign bit therefore it is a lossy and irreversible downcast.cc: @krzysz00 @MaheshRavishankar @Muzammiluddin-Syed-ECE