Skip to content

Add arith expansion of f8E8M0 type for extf/trunc ops #140332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);

/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);

/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);

Expand Down
8 changes: 5 additions & 3 deletions mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ include "mlir/Pass/PassBase.td"
def ArithExpandOpsPass : Pass<"arith-expand"> {
let summary = "Legalize Arith ops to be convertible to LLVM.";
let dependentDialects = ["vector::VectorDialect"];
let options = [
Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
"Enable the BF16 expansion patterns">,
let options =
[Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
"Enable the BF16 expansion patterns">,
Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
"Enable the F8E8M0 expansion patterns">,
];
}

Expand Down
163 changes: 133 additions & 30 deletions mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}

/// Creates shapedType using shape from cloneFrom and base type from cloneTo
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
return shapedTy.clone(cloneTo);
}
return cloneTo;
}

namespace {

/// Expands CeilDivUIOp (n, m) into
Expand Down Expand Up @@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
}

Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
i16Ty = shapedTy.clone(i16Ty);
i32Ty = shapedTy.clone(i32Ty);
}
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());

Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
Expand Down Expand Up @@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
op, "only applicable to default rounding mode.");
}

Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
i16Ty = shapedTy.clone(i16Ty);
i32Ty = shapedTy.clone(i32Ty);
f32Ty = shapedTy.clone(f32Ty);
}
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());

// Algorithm borrowed from this excellent code:
// https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
Expand All @@ -291,7 +289,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I noticed invalid case style for varible name. Therefore changed it.

// Small constants used to address bits.
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
Expand All @@ -313,18 +311,104 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
Value normalCaseResult_i16 =
Value normalCaseResultI16 =
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
Value result = b.create<arith::BitcastOp>(resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
};

struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value operand = op.getOperand();
Type operandTy = operand.getType();
Type resultTy = op.getType();
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultETy = getElementTypeOrSelf(resultTy);

if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
}

Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());

Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
// create constants for NaNs
Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);

Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);

Value isNan =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
// select for NaNs
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
if (resultETy.getIntOrFloatBitWidth() < 32) {
result = b.create<arith::TruncFOp>(resultTy, result);
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
result = b.create<arith::ExtFOp>(resultTy, result);
}
rewriter.replaceOp(op, result);
return success();
}
};

/*
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
they all map to NaN in F8E8M0 Type.
*/
struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::TruncFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value operand = op.getOperand();
Type operandTy = operand.getType();
Type operandETy = getElementTypeOrSelf(operandTy);
Type resultTy = op.getType();
Type resultETy = getElementTypeOrSelf(resultTy);
if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
}

if (op.getRoundingmodeAttr()) {
return rewriter.notifyMatchFailure(
op, "only applicable to default rounding mode.");
}

Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());

if (operandETy.getIntOrFloatBitWidth() < 32) {
operand = b.create<arith::ExtFOp>(f32Ty, operand);
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
operand = b.create<arith::TruncFOp>(f32Ty, operand);
}
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
rewriter.replaceOp(op, result);
return success();
}
};

struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
Expand Down Expand Up @@ -353,20 +437,34 @@ struct ArithExpandOpsPass

if (includeBf16) {
arith::populateExpandBFloat16Patterns(patterns);
target.addDynamicallyLegalOp<arith::ExtFOp>(
[](arith::ExtFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
return !(inETy.isBF16() && outETy.isF32());
});

target.addDynamicallyLegalOp<arith::TruncFOp>(
[](arith::TruncFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
return !(inETy.isF32() && outETy.isBF16());
});
}
if (includeF8E8M0) {
arith::populateExpandF8E8M0Patterns(patterns);
}

target.addDynamicallyLegalOp<arith::ExtFOp>(
[=](arith::ExtFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
if (includeBf16)
legalTypes &= !(inETy.isBF16() && outETy.isF32());
if (includeF8E8M0)
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
return legalTypes;
});

target.addDynamicallyLegalOp<arith::TruncFOp>(
[=](arith::TruncFOp op) {
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
bool legalTypes = true;
if (includeBf16)
legalTypes &= !(inETy.isF32() && outETy.isBF16());
if (includeF8E8M0)
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
return legalTypes;
});

// clang-format on
if (failed(applyPartialConversion(getOperation(), target,
Expand All @@ -389,6 +487,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}

void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
patterns.getContext());
}

void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
// clang-format off
Expand Down
130 changes: 129 additions & 1 deletion mlir/test/Dialect/Arith/expand-ops.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s

// Test ceil divide with signed integer
// CHECK-LABEL: func @ceildivi
Expand Down Expand Up @@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
// CHECK-LABEL: @truncf_vector_f32
// CHECK-NOT: arith.truncf

// -----
func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
%0 = arith.truncf %arg0 : f32 to f8E8M0FNU
return %0 : f8E8M0FNU
}
// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
// CHECK: return %[[RESULT]]

// -----

func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
%0 = arith.truncf %arg0 : f16 to f8E8M0FNU
return %0 : f8E8M0FNU
}
// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU
// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
// CHECK: return %[[RESULT]]

// -----

func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
%0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
return %0 : vector<4xf8E8M0FNU>
}

// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
// CHECK-NOT: arith.truncf

// -----

func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
%0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
return %0 : vector<4xf8E8M0FNU>
}

// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
// CHECK-NOT: arith.truncf

// -----

func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
%0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
return %0 : vector<4xf8E8M0FNU>
}

// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
// CHECK-NOT: arith.truncf


// -----
func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
%0 = arith.extf %arg0 : f8E8M0FNU to f32
return %0 : f32
}

// CHECK-LABLE: @extf_f8E8M0FNU_to_f32
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
// CHECK: return %[[RESULT]]

// -----

func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
%0 = arith.extf %arg0 : f8E8M0FNU to f16
return %0 : f16
}

// CHECK-LABLE: @extf_f8E8M0FNU_to_f16
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
// CHECK: return %[[F16_RESULT]]

// -----

func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
%0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
return %0 : vector<4xf32>
}

// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
// CHECK-NOT: arith.extf

// -----

func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
%0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
return %0 : vector<4xf16>
}

// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
// CHECK-NOT: arith.extf

// -----

func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
%0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
return %0 : vector<4xbf16>
}

// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
// CHECK-NOT: arith.extf


// -----

func.func @maxsi(%a: i32, %b: i32) -> i32 {
Expand Down