Skip to content

Commit f023da1

Browse files
[mlir][IR] Remove factory methods from FloatType (#123026)
This commit removes convenience methods from `FloatType` to make it independent of concrete interface implementations. See discussion here: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361 Note for LLVM integration: Replace `FloatType::getF32(` with `Float32Type::get(` etc.
1 parent f711aa9 commit f023da1

29 files changed

+198
-297
lines changed

flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ constexpr TypeBuilderFunc getModel<unsigned long long>() {
320320
template <>
321321
constexpr TypeBuilderFunc getModel<double>() {
322322
return [](mlir::MLIRContext *context) -> mlir::Type {
323-
return mlir::FloatType::getF64(context);
323+
return mlir::Float64Type::get(context);
324324
};
325325
}
326326
template <>
@@ -347,11 +347,11 @@ constexpr TypeBuilderFunc getModel<long double>() {
347347
static_assert(size == 16 || size == 10 || size == 8,
348348
"unsupported long double size");
349349
if constexpr (size == 16)
350-
return mlir::FloatType::getF128(context);
350+
return mlir::Float128Type::get(context);
351351
if constexpr (size == 10)
352-
return mlir::FloatType::getF80(context);
352+
return mlir::Float80Type::get(context);
353353
if constexpr (size == 8)
354-
return mlir::FloatType::getF64(context);
354+
return mlir::Float64Type::get(context);
355355
llvm_unreachable("failed static assert");
356356
};
357357
}
@@ -369,7 +369,7 @@ constexpr TypeBuilderFunc getModel<const long double *>() {
369369
template <>
370370
constexpr TypeBuilderFunc getModel<float>() {
371371
return [](mlir::MLIRContext *context) -> mlir::Type {
372-
return mlir::FloatType::getF32(context);
372+
return mlir::Float32Type::get(context);
373373
};
374374
}
375375
template <>

flang/lib/Lower/ConvertType.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@ static mlir::Type genRealType(mlir::MLIRContext *context, int kind) {
3636
Fortran::common::TypeCategory::Real, kind)) {
3737
switch (kind) {
3838
case 2:
39-
return mlir::FloatType::getF16(context);
39+
return mlir::Float16Type::get(context);
4040
case 3:
41-
return mlir::FloatType::getBF16(context);
41+
return mlir::BFloat16Type::get(context);
4242
case 4:
43-
return mlir::FloatType::getF32(context);
43+
return mlir::Float32Type::get(context);
4444
case 8:
45-
return mlir::FloatType::getF64(context);
45+
return mlir::Float64Type::get(context);
4646
case 10:
47-
return mlir::FloatType::getF80(context);
47+
return mlir::Float80Type::get(context);
4848
case 16:
49-
return mlir::FloatType::getF128(context);
49+
return mlir::Float128Type::get(context);
5050
}
5151
}
5252
llvm_unreachable("REAL type translation not implemented");

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,17 +105,17 @@ mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) {
105105
mlir::Type fir::FirOpBuilder::getRealType(int kind) {
106106
switch (kindMap.getRealTypeID(kind)) {
107107
case llvm::Type::TypeID::HalfTyID:
108-
return mlir::FloatType::getF16(getContext());
108+
return mlir::Float16Type::get(getContext());
109109
case llvm::Type::TypeID::BFloatTyID:
110-
return mlir::FloatType::getBF16(getContext());
110+
return mlir::BFloat16Type::get(getContext());
111111
case llvm::Type::TypeID::FloatTyID:
112-
return mlir::FloatType::getF32(getContext());
112+
return mlir::Float32Type::get(getContext());
113113
case llvm::Type::TypeID::DoubleTyID:
114-
return mlir::FloatType::getF64(getContext());
114+
return mlir::Float64Type::get(getContext());
115115
case llvm::Type::TypeID::X86_FP80TyID:
116-
return mlir::FloatType::getF80(getContext());
116+
return mlir::Float80Type::get(getContext());
117117
case llvm::Type::TypeID::FP128TyID:
118-
return mlir::FloatType::getF128(getContext());
118+
return mlir::Float128Type::get(getContext());
119119
default:
120120
fir::emitFatalError(mlir::UnknownLoc::get(getContext()),
121121
"unsupported type !fir.real<kind>");

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2367,7 +2367,7 @@ mlir::Value IntrinsicLibrary::genAcosd(mlir::Type resultType,
23672367
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
23682368
llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
23692369
mlir::Value dfactor = builder.createRealConstant(
2370-
loc, mlir::FloatType::getF64(context), pi / llvm::APFloat(180.0));
2370+
loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0));
23712371
mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
23722372
mlir::Value arg = builder.create<mlir::arith::MulFOp>(loc, args[0], factor);
23732373
return getRuntimeCallGenerator("acos", ftype)(builder, loc, {arg});
@@ -2518,7 +2518,7 @@ mlir::Value IntrinsicLibrary::genAsind(mlir::Type resultType,
25182518
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
25192519
llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
25202520
mlir::Value dfactor = builder.createRealConstant(
2521-
loc, mlir::FloatType::getF64(context), pi / llvm::APFloat(180.0));
2521+
loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0));
25222522
mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
25232523
mlir::Value arg = builder.create<mlir::arith::MulFOp>(loc, args[0], factor);
25242524
return getRuntimeCallGenerator("asin", ftype)(builder, loc, {arg});
@@ -2544,7 +2544,7 @@ mlir::Value IntrinsicLibrary::genAtand(mlir::Type resultType,
25442544
}
25452545
llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
25462546
mlir::Value dfactor = builder.createRealConstant(
2547-
loc, mlir::FloatType::getF64(context), llvm::APFloat(180.0) / pi);
2547+
loc, mlir::Float64Type::get(context), llvm::APFloat(180.0) / pi);
25482548
mlir::Value factor = builder.createConvert(loc, resultType, dfactor);
25492549
return builder.create<mlir::arith::MulFOp>(loc, atan, factor);
25502550
}
@@ -2569,7 +2569,7 @@ mlir::Value IntrinsicLibrary::genAtanpi(mlir::Type resultType,
25692569
}
25702570
llvm::APFloat inv_pi = llvm::APFloat(llvm::numbers::inv_pi);
25712571
mlir::Value dfactor =
2572-
builder.createRealConstant(loc, mlir::FloatType::getF64(context), inv_pi);
2572+
builder.createRealConstant(loc, mlir::Float64Type::get(context), inv_pi);
25732573
mlir::Value factor = builder.createConvert(loc, resultType, dfactor);
25742574
return builder.create<mlir::arith::MulFOp>(loc, atan, factor);
25752575
}
@@ -3124,7 +3124,7 @@ mlir::Value IntrinsicLibrary::genCosd(mlir::Type resultType,
31243124
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
31253125
llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
31263126
mlir::Value dfactor = builder.createRealConstant(
3127-
loc, mlir::FloatType::getF64(context), pi / llvm::APFloat(180.0));
3127+
loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0));
31283128
mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
31293129
mlir::Value arg = builder.create<mlir::arith::MulFOp>(loc, args[0], factor);
31303130
return getRuntimeCallGenerator("cos", ftype)(builder, loc, {arg});
@@ -4418,12 +4418,12 @@ IntrinsicLibrary::genIeeeCopySign(mlir::Type resultType,
44184418
mlir::FloatType yRealType =
44194419
mlir::dyn_cast<mlir::FloatType>(yRealVal.getType());
44204420

4421-
if (yRealType == mlir::FloatType::getBF16(builder.getContext())) {
4421+
if (yRealType == mlir::BFloat16Type::get(builder.getContext())) {
44224422
// Workaround: CopySignOp and BitcastOp don't work for kind 3 arg Y.
44234423
// This conversion should always preserve the sign bit.
44244424
yRealVal = builder.createConvert(
4425-
loc, mlir::FloatType::getF32(builder.getContext()), yRealVal);
4426-
yRealType = mlir::FloatType::getF32(builder.getContext());
4425+
loc, mlir::Float32Type::get(builder.getContext()), yRealVal);
4426+
yRealType = mlir::Float32Type::get(builder.getContext());
44274427
}
44284428

44294429
// Args have the same type.
@@ -4979,7 +4979,7 @@ mlir::Value IntrinsicLibrary::genIeeeReal(mlir::Type resultType,
49794979

49804980
assert(args.size() == 2);
49814981
mlir::Type i1Ty = builder.getI1Type();
4982-
mlir::Type f32Ty = mlir::FloatType::getF32(builder.getContext());
4982+
mlir::Type f32Ty = mlir::Float32Type::get(builder.getContext());
49834983
mlir::Value a = args[0];
49844984
mlir::Type aType = a.getType();
49854985

@@ -5179,7 +5179,7 @@ mlir::Value IntrinsicLibrary::genIeeeRem(mlir::Type resultType,
51795179
mlir::Value x = args[0];
51805180
mlir::Value y = args[1];
51815181
if (mlir::dyn_cast<mlir::FloatType>(resultType).getWidth() < 32) {
5182-
mlir::Type f32Ty = mlir::FloatType::getF32(builder.getContext());
5182+
mlir::Type f32Ty = mlir::Float32Type::get(builder.getContext());
51835183
x = builder.create<fir::ConvertOp>(loc, f32Ty, x);
51845184
y = builder.create<fir::ConvertOp>(loc, f32Ty, y);
51855185
} else {
@@ -5213,7 +5213,7 @@ mlir::Value IntrinsicLibrary::genIeeeRint(mlir::Type resultType,
52135213
}
52145214
if (mlir::cast<mlir::FloatType>(resultType).getWidth() == 16)
52155215
a = builder.create<fir::ConvertOp>(
5216-
loc, mlir::FloatType::getF32(builder.getContext()), a);
5216+
loc, mlir::Float32Type::get(builder.getContext()), a);
52175217
mlir::Value result = builder.create<fir::ConvertOp>(
52185218
loc, resultType, genRuntimeCall("nearbyint", a.getType(), a));
52195219
if (isStaticallyPresent(args[1])) {
@@ -5298,10 +5298,10 @@ mlir::Value IntrinsicLibrary::genIeeeSignbit(mlir::Type resultType,
52985298
mlir::Value realVal = args[0];
52995299
mlir::FloatType realType = mlir::dyn_cast<mlir::FloatType>(realVal.getType());
53005300
int bitWidth = realType.getWidth();
5301-
if (realType == mlir::FloatType::getBF16(builder.getContext())) {
5301+
if (realType == mlir::BFloat16Type::get(builder.getContext())) {
53025302
// Workaround: can't bitcast or convert real(3) to integer(2) or real(2).
53035303
realVal = builder.createConvert(
5304-
loc, mlir::FloatType::getF32(builder.getContext()), realVal);
5304+
loc, mlir::Float32Type::get(builder.getContext()), realVal);
53055305
bitWidth = 32;
53065306
}
53075307
mlir::Type intType = builder.getIntegerType(bitWidth);
@@ -6065,7 +6065,7 @@ mlir::Value IntrinsicLibrary::genModulo(mlir::Type resultType,
60656065
auto fastMathFlags = builder.getFastMathFlags();
60666066
// F128 arith::RemFOp may be lowered to a runtime call that may be unsupported
60676067
// on the target, so generate a call to Fortran Runtime's ModuloReal16.
6068-
if (resultType == mlir::FloatType::getF128(builder.getContext()) ||
6068+
if (resultType == mlir::Float128Type::get(builder.getContext()) ||
60696069
(fastMathFlags & mlir::arith::FastMathFlags::ninf) ==
60706070
mlir::arith::FastMathFlags::none)
60716071
return builder.createConvert(
@@ -6254,7 +6254,7 @@ mlir::Value IntrinsicLibrary::genNearest(mlir::Type resultType,
62546254
mlir::FloatType yType = mlir::dyn_cast<mlir::FloatType>(args[1].getType());
62556255
const unsigned yBitWidth = yType.getWidth();
62566256
if (xType != yType) {
6257-
mlir::Type f32Ty = mlir::FloatType::getF32(builder.getContext());
6257+
mlir::Type f32Ty = mlir::Float32Type::get(builder.getContext());
62586258
if (xBitWidth < 32)
62596259
x1 = builder.createConvert(loc, f32Ty, x1);
62606260
if (yBitWidth > 32 && yBitWidth > xBitWidth)
@@ -7205,7 +7205,7 @@ mlir::Value IntrinsicLibrary::genSind(mlir::Type resultType,
72057205
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
72067206
llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
72077207
mlir::Value dfactor = builder.createRealConstant(
7208-
loc, mlir::FloatType::getF64(context), pi / llvm::APFloat(180.0));
7208+
loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0));
72097209
mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
72107210
mlir::Value arg = builder.create<mlir::arith::MulFOp>(loc, args[0], factor);
72117211
return getRuntimeCallGenerator("sin", ftype)(builder, loc, {arg});
@@ -7286,7 +7286,7 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType,
72867286
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
72877287
llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
72887288
mlir::Value dfactor = builder.createRealConstant(
7289-
loc, mlir::FloatType::getF64(context), pi / llvm::APFloat(180.0));
7289+
loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0));
72907290
mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
72917291
mlir::Value arg = builder.create<mlir::arith::MulFOp>(loc, args[0], factor);
72927292
return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg});

flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,7 +1579,7 @@ PPCIntrinsicLibrary::genVecConvert(mlir::Type resultType,
15791579

15801580
return callOp.getResult(0);
15811581
} else if (width == 64) {
1582-
auto fTy{mlir::FloatType::getF64(context)};
1582+
auto fTy{mlir::Float64Type::get(context)};
15831583
auto ty{mlir::VectorType::get(2, fTy)};
15841584

15851585
// vec_vtf(arg1, arg2) = fmul(1.0 / (1 << arg2), llvm.sitofp(arg1))
@@ -1639,7 +1639,7 @@ PPCIntrinsicLibrary::genVecConvert(mlir::Type resultType,
16391639
newArgs[0] =
16401640
builder.create<fir::CallOp>(loc, funcOp, newArgs).getResult(0);
16411641
auto fvf32Ty{newArgs[0].getType()};
1642-
auto f32type{mlir::FloatType::getF32(context)};
1642+
auto f32type{mlir::Float32Type::get(context)};
16431643
auto mvf32Ty{mlir::VectorType::get(4, f32type)};
16441644
newArgs[0] = builder.createConvert(loc, mvf32Ty, newArgs[0]);
16451645

@@ -1949,7 +1949,7 @@ PPCIntrinsicLibrary::genVecLdCallGrp(mlir::Type resultType,
19491949
fname = isBEVecElemOrderOnLE() ? "llvm.ppc.vsx.lxvd2x.be"
19501950
: "llvm.ppc.vsx.lxvd2x";
19511951
// llvm.ppc.altivec.lxvd2x* returns <2 x double>
1952-
intrinResTy = mlir::VectorType::get(2, mlir::FloatType::getF64(context));
1952+
intrinResTy = mlir::VectorType::get(2, mlir::Float64Type::get(context));
19531953
} break;
19541954
case VecOp::Xlw4:
19551955
fname = isBEVecElemOrderOnLE() ? "llvm.ppc.vsx.lxvw4x.be"
@@ -2092,7 +2092,7 @@ PPCIntrinsicLibrary::genVecPerm(mlir::Type resultType,
20922092
auto mlirTy{vecTyInfo.toMlirVectorType(context)};
20932093

20942094
auto vi32Ty{mlir::VectorType::get(4, mlir::IntegerType::get(context, 32))};
2095-
auto vf64Ty{mlir::VectorType::get(2, mlir::FloatType::getF64(context))};
2095+
auto vf64Ty{mlir::VectorType::get(2, mlir::Float64Type::get(context))};
20962096

20972097
auto mArg0{builder.createConvert(loc, mlirTy, argBases[0])};
20982098
auto mArg1{builder.createConvert(loc, mlirTy, argBases[1])};

0 commit comments

Comments
 (0)