Skip to content

Commit 3db8ed0

Browse files
[mlir][arith] Add support for fptosi, fptoui to ArithToAPFloat (#169277)
Add support for `arith.fptosi` and `arith.fptoui`.
1 parent 196f6de commit 3db8ed0

File tree

4 files changed

+116
-2
lines changed

4 files changed

+116
-2
lines changed

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,8 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
159159
Location loc = op.getLoc();
160160
auto inFloatTy = cast<FloatType>(op.getOperand().getType());
161161
auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
162-
auto int64Type = rewriter.getI64Type();
163162
Value operandBits = arith::ExtUIOp::create(
164-
rewriter, loc, int64Type,
163+
rewriter, loc, i64Type,
165164
arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
166165

167166
// Call APFloat function.
@@ -185,6 +184,63 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
185184
SymbolOpInterface symTable;
186185
};
187186

187+
template <typename OpTy>
188+
struct FpToIntConversion final : OpRewritePattern<OpTy> {
189+
FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable,
190+
bool isUnsigned, PatternBenefit benefit = 1)
191+
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
192+
isUnsigned(isUnsigned) {}
193+
194+
LogicalResult matchAndRewrite(OpTy op,
195+
PatternRewriter &rewriter) const override {
196+
if (op.getType().getIntOrFloatBitWidth() > 64)
197+
return rewriter.notifyMatchFailure(
198+
op, "result type > 64 bits is not supported");
199+
200+
// Get APFloat function from runtime library.
201+
auto i1Type = IntegerType::get(symTable->getContext(), 1);
202+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
203+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
204+
FailureOr<FuncOp> fn =
205+
lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
206+
{i32Type, i32Type, i1Type, i64Type});
207+
if (failed(fn))
208+
return fn;
209+
210+
rewriter.setInsertionPoint(op);
211+
// Cast operands to 64-bit integers.
212+
Location loc = op.getLoc();
213+
auto inFloatTy = cast<FloatType>(op.getOperand().getType());
214+
auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
215+
Value operandBits = arith::ExtUIOp::create(
216+
rewriter, loc, i64Type,
217+
arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
218+
219+
// Call APFloat function.
220+
Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
221+
auto outIntTy = cast<IntegerType>(op.getType());
222+
Value outWidthValue = arith::ConstantOp::create(
223+
rewriter, loc, i32Type,
224+
rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
225+
Value isUnsignedValue = arith::ConstantOp::create(
226+
rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
227+
SmallVector<Value> params = {inSemValue, outWidthValue, isUnsignedValue,
228+
operandBits};
229+
auto resultOp =
230+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
231+
SymbolRefAttr::get(*fn), params);
232+
233+
// Truncate result to the original width.
234+
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntTy,
235+
resultOp->getResult(0));
236+
rewriter.replaceOp(op, truncatedBits);
237+
return success();
238+
}
239+
240+
SymbolOpInterface symTable;
241+
bool isUnsigned;
242+
};
243+
188244
namespace {
189245
struct ArithToAPFloatConversionPass final
190246
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -209,6 +265,10 @@ void ArithToAPFloatConversionPass::runOnOperation() {
209265
patterns
210266
.add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>>(
211267
context, getOperation());
268+
patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
269+
/*isUnsigned=*/false);
270+
patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
271+
/*isUnsigned=*/true);
212272
LogicalResult result = success();
213273
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
214274
if (diag.getSeverity() == DiagnosticSeverity::Error) {

mlir/lib/ExecutionEngine/APFloatWrappers.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
// APFloatBase::Semantics enum value.
2121
//
2222
#include "llvm/ADT/APFloat.h"
23+
#include "llvm/ADT/APSInt.h"
2324

2425
#ifdef _WIN32
2526
#ifndef MLIR_APFLOAT_WRAPPERS_EXPORT
@@ -101,4 +102,21 @@ _mlir_apfloat_convert(int32_t inSemantics, int32_t outSemantics, uint64_t a) {
101102
llvm::APInt result = val.bitcastToAPInt();
102103
return result.getZExtValue();
103104
}
105+
106+
MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int(
107+
int32_t semantics, int32_t resultWidth, bool isUnsigned, uint64_t a) {
108+
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
109+
static_cast<llvm::APFloatBase::Semantics>(semantics));
110+
unsigned inputWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
111+
llvm::APFloat val(sem, llvm::APInt(inputWidth, a));
112+
llvm::APSInt result(resultWidth, isUnsigned);
113+
bool isExact;
114+
// TODO: Custom rounding modes are not supported yet.
115+
val.convertToInteger(result, llvm::RoundingMode::NearestTiesToEven, &isExact);
116+
// This function always returns uint64_t, regardless of the desired result
117+
// width. It does not matter whether we zero-extend or sign-extend the APSInt
118+
// to 64 bits because the generated IR in arith-to-apfloat will truncate the
119+
// result to the desired result width.
120+
return result.getZExtValue();
121+
}
104122
}

mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,29 @@ func.func @truncf(%arg0: bf16) {
148148
%0 = arith.truncf %arg0 : bf16 to f4E2M1FN
149149
return
150150
}
151+
152+
// -----
153+
154+
// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64
155+
// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32
156+
// CHECK: %[[out_width:.*]] = arith.constant 4 : i32
157+
// CHECK: %[[is_unsigned:.*]] = arith.constant false
158+
// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
159+
// CHECK: arith.trunci %[[res]] : i64 to i4
160+
func.func @fptosi(%arg0: f16) {
161+
%0 = arith.fptosi %arg0 : f16 to i4
162+
return
163+
}
164+
165+
// -----
166+
167+
// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64
168+
// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32
169+
// CHECK: %[[out_width:.*]] = arith.constant 4 : i32
170+
// CHECK: %[[is_unsigned:.*]] = arith.constant true
171+
// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
172+
// CHECK: arith.trunci %[[res]] : i64 to i4
173+
func.func @fptoui(%arg0: f16) {
174+
%0 = arith.fptoui %arg0 : f16 to i4
175+
return
176+
}

mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,15 @@ func.func @entry() {
4343
%cvt = arith.truncf %b2 : f32 to f8E4M3FN
4444
vector.print %cvt : f8E4M3FN
4545

46+
// CHECK-NEXT: 1
47+
// Bit pattern: 01, interpreted as signed integer: 1
48+
%cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2
49+
vector.print %cvt_int_signed : i2
50+
51+
// CHECK-NEXT: -2
52+
// Bit pattern: 10, interpreted as signed integer: -2
53+
%cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2
54+
vector.print %cvt_int_unsigned : i2
55+
4656
return
4757
}

0 commit comments

Comments
 (0)