Skip to content

Commit 880d370

Browse files
authored
[APFloat] Add APFloat support for FP4 data type (#95392)
This patch adds APFloat type support for the E2M1 FP4 datatype. The definitions for this format are detailed in section 5.3.3 of the OCP specification, which can be accessed here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent e83adfe commit 880d370

File tree

4 files changed

+283
-7
lines changed

4 files changed

+283
-7
lines changed

clang/lib/AST/MicrosoftMangle.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
901901
case APFloat::S_FloatTF32:
902902
case APFloat::S_Float6E3M2FN:
903903
case APFloat::S_Float6E2M3FN:
904+
case APFloat::S_Float4E2M1FN:
904905
llvm_unreachable("Tried to mangle unexpected APFloat semantics");
905906
}
906907

llvm/include/llvm/ADT/APFloat.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ struct APFloatBase {
197197
// types, there are no infinity or NaN values. The format is detailed in
198198
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
199199
S_Float6E2M3FN,
200+
// 4-bit floating point number with bit layout S1E2M1. Unlike IEEE-754
201+
// types, there are no infinity or NaN values. The format is detailed in
202+
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
203+
S_Float4E2M1FN,
200204

201205
S_x87DoubleExtended,
202206
S_MaxSemantics = S_x87DoubleExtended,
@@ -219,6 +223,7 @@ struct APFloatBase {
219223
static const fltSemantics &FloatTF32() LLVM_READNONE;
220224
static const fltSemantics &Float6E3M2FN() LLVM_READNONE;
221225
static const fltSemantics &Float6E2M3FN() LLVM_READNONE;
226+
static const fltSemantics &Float4E2M1FN() LLVM_READNONE;
222227
static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
223228

224229
/// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -639,6 +644,7 @@ class IEEEFloat final : public APFloatBase {
639644
APInt convertFloatTF32APFloatToAPInt() const;
640645
APInt convertFloat6E3M2FNAPFloatToAPInt() const;
641646
APInt convertFloat6E2M3FNAPFloatToAPInt() const;
647+
APInt convertFloat4E2M1FNAPFloatToAPInt() const;
642648
void initFromAPInt(const fltSemantics *Sem, const APInt &api);
643649
template <const fltSemantics &S> void initFromIEEEAPInt(const APInt &api);
644650
void initFromHalfAPInt(const APInt &api);
@@ -656,6 +662,7 @@ class IEEEFloat final : public APFloatBase {
656662
void initFromFloatTF32APInt(const APInt &api);
657663
void initFromFloat6E3M2FNAPInt(const APInt &api);
658664
void initFromFloat6E2M3FNAPInt(const APInt &api);
665+
void initFromFloat4E2M1FNAPInt(const APInt &api);
659666

660667
void assign(const IEEEFloat &);
661668
void copySignificand(const IEEEFloat &);
@@ -1067,6 +1074,7 @@ class APFloat : public APFloatBase {
10671074
// Below Semantics do not support {NaN or Inf}
10681075
case APFloat::S_Float6E3M2FN:
10691076
case APFloat::S_Float6E2M3FN:
1077+
case APFloat::S_Float4E2M1FN:
10701078
return false;
10711079
}
10721080
}

llvm/lib/Support/APFloat.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ enum class fltNonfiniteBehavior {
6969
// encodings do not distinguish between signalling and quiet NaN.
7070
NanOnly,
7171

72-
// This behavior is present in Float6E3M2FN and Float6E2M3FN types,
73-
// which do not support Inf or NaN values.
72+
// This behavior is present in Float6E3M2FN, Float6E2M3FN, and
73+
// Float4E2M1FN types, which do not support Inf or NaN values.
7474
FiniteOnly,
7575
};
7676

@@ -147,6 +147,8 @@ static constexpr fltSemantics semFloat6E3M2FN = {
147147
4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
148148
// Semantics record for the Float6E2M3FN format. It uses the FiniteOnly
// non-finite behavior, i.e. the format has no Inf or NaN encodings.
// NOTE(review): the initializer order is presumably {maxExponent, minExponent,
// precision, sizeInBits, nonFiniteBehavior} -- confirm against the
// fltSemantics definition, which is not visible here.
static constexpr fltSemantics semFloat6E2M3FN = {
149149
2, 0, 4, 6, fltNonfiniteBehavior::FiniteOnly};
150+
// Semantics record for the 4-bit Float4E2M1FN format (bit layout S1E2M1).
// It uses the FiniteOnly non-finite behavior, i.e. the format has no Inf or
// NaN encodings.
// NOTE(review): the initializer order is presumably {maxExponent, minExponent,
// precision, sizeInBits, nonFiniteBehavior} -- confirm against the
// fltSemantics definition, which is not visible here.
static constexpr fltSemantics semFloat4E2M1FN = {
151+
2, 0, 2, 4, fltNonfiniteBehavior::FiniteOnly};
150152
// Semantics record returned by APFloatBase::x87DoubleExtended(). Only four
// values are supplied, so any remaining fltSemantics member (presumably the
// non-finite behavior -- confirm) takes its default.
static constexpr fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
151153
// All-zero semantics record. NOTE(review): presumably the placeholder behind
// the "Bogus" pseudo-semantics, which match no real format -- confirm at the
// point of use.
static constexpr fltSemantics semBogus = {0, 0, 0, 0};
152154

@@ -218,6 +220,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
218220
return Float6E3M2FN();
219221
case S_Float6E2M3FN:
220222
return Float6E2M3FN();
223+
case S_Float4E2M1FN:
224+
return Float4E2M1FN();
221225
case S_x87DoubleExtended:
222226
return x87DoubleExtended();
223227
}
@@ -254,6 +258,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
254258
return S_Float6E3M2FN;
255259
else if (&Sem == &llvm::APFloat::Float6E2M3FN())
256260
return S_Float6E2M3FN;
261+
else if (&Sem == &llvm::APFloat::Float4E2M1FN())
262+
return S_Float4E2M1FN;
257263
else if (&Sem == &llvm::APFloat::x87DoubleExtended())
258264
return S_x87DoubleExtended;
259265
else
@@ -278,6 +284,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
278284
/// Returns a reference to semFloatTF32, the semantics record for FloatTF32.
const fltSemantics &APFloatBase::FloatTF32() {
  return semFloatTF32;
}
279285
/// Returns a reference to semFloat6E3M2FN, the semantics record for the
/// Float6E3M2FN format.
const fltSemantics &APFloatBase::Float6E3M2FN() {
  return semFloat6E3M2FN;
}
280286
/// Returns a reference to semFloat6E2M3FN, the semantics record for the
/// Float6E2M3FN format.
const fltSemantics &APFloatBase::Float6E2M3FN() {
  return semFloat6E2M3FN;
}
287+
/// Returns a reference to semFloat4E2M1FN, the semantics record for the 4-bit
/// Float4E2M1FN (S1E2M1) format.
const fltSemantics &APFloatBase::Float4E2M1FN() {
  return semFloat4E2M1FN;
}
281288
/// Returns a reference to semX87DoubleExtended, the file-local semantics
/// record for the x87 double-extended format.
const fltSemantics &APFloatBase::x87DoubleExtended() {
282289
return semX87DoubleExtended;
283290
}
@@ -3640,6 +3647,11 @@ APInt IEEEFloat::convertFloat6E2M3FNAPFloatToAPInt() const {
36403647
return convertIEEEFloatToAPInt<semFloat6E2M3FN>();
36413648
}
36423649

3650+
/// Produces the APInt for a value whose semantics are semFloat4E2M1FN by
/// delegating to the generic convertIEEEFloatToAPInt template. Invoked from
/// bitcastToAPInt() when the semantics pointer matches semFloat4E2M1FN.
/// Asserts that partCount() is exactly 1.
APInt IEEEFloat::convertFloat4E2M1FNAPFloatToAPInt() const {
3651+
assert(partCount() == 1);
3652+
return convertIEEEFloatToAPInt<semFloat4E2M1FN>();
3653+
}
3654+
36433655
// This function creates an APInt that is just a bit map of the floating
36443656
// point constant as it would appear in memory. It is not a conversion,
36453657
// and treating the result as a normal integer is unlikely to be useful.
@@ -3687,6 +3699,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
36873699
if (semantics == (const llvm::fltSemantics *)&semFloat6E2M3FN)
36883700
return convertFloat6E2M3FNAPFloatToAPInt();
36893701

3702+
if (semantics == (const llvm::fltSemantics *)&semFloat4E2M1FN)
3703+
return convertFloat4E2M1FNAPFloatToAPInt();
3704+
36903705
assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
36913706
"unknown format!");
36923707
return convertF80LongDoubleAPFloatToAPInt();
@@ -3911,6 +3926,10 @@ void IEEEFloat::initFromFloat6E2M3FNAPInt(const APInt &api) {
39113926
initFromIEEEAPInt<semFloat6E2M3FN>(api);
39123927
}
39133928

3929+
/// Initializes this value from \p api using the semFloat4E2M1FN semantics by
/// forwarding to the generic initFromIEEEAPInt template. Reached from
/// initFromAPInt() when the requested semantics are semFloat4E2M1FN.
void IEEEFloat::initFromFloat4E2M1FNAPInt(const APInt &api) {
3930+
initFromIEEEAPInt<semFloat4E2M1FN>(api);
3931+
}
3932+
39143933
/// Treat api as containing the bits of a floating point number.
39153934
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
39163935
assert(api.getBitWidth() == Sem->sizeInBits);
@@ -3944,6 +3963,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
39443963
return initFromFloat6E3M2FNAPInt(api);
39453964
if (Sem == &semFloat6E2M3FN)
39463965
return initFromFloat6E2M3FNAPInt(api);
3966+
if (Sem == &semFloat4E2M1FN)
3967+
return initFromFloat4E2M1FNAPInt(api);
39473968

39483969
llvm_unreachable(nullptr);
39493970
}

0 commit comments

Comments
 (0)