Skip to content

Commit f5f39b3

Browse files
authored
Merge pull request #27 from vkuzo/20250201_llvm_rounding_v2
more comprehensive probing
2 parents e35c1b8 + d87fee5 commit f5f39b3

File tree

3 files changed

+149
-34
lines changed

3 files changed

+149
-34
lines changed

llvm_e8m0/README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# probing llvm's APFloat8 e8m0 data type's rounding behavior
2+
3+
A script to verify that LLVM's float32 -> e8m0 cast uses round-to-nearest, ties-to-even rounding - it does!
4+
5+
To use E8M0 in LLVM you need LLVM v20.0+; I had to build LLVM from source since the latest released version was 19.x.
6+
7+
To run:
8+
9+
```
10+
./run.sh
11+
```

llvm_e8m0/apfloat_example.cpp

Lines changed: 129 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ using namespace llvm;
1212
// PR adding this functionality, for context: https://github.com/llvm/llvm-project/pull/107127
1313
// this isn't in released LLVM versions, so need to build from source to test it out
1414

15+
// access bits of a float
16+
union float_or_uint32_t {
17+
float f;
18+
uint32_t i;
19+
};
20+
1521

1622
void int_exponent_to_e8m0(int unbiased_exponent, const std::string& description) {
1723
std::cout << "description: " << description << ", unbiased exponent: " << unbiased_exponent << std::endl;
@@ -35,45 +41,98 @@ void int_exponent_to_e8m0(int unbiased_exponent, const std::string& description)
3541
std::cout << std::endl;
3642
}
3743

38-
void float32_val_to_e8m0(float val, const std::string& description) {
44+
float float32_val_to_e8m0(
45+
float val,
46+
const std::string& description,
47+
bool debug=false
48+
) {
49+
3950
APFloat val_float32(val); // 32-bit float
40-
std::cout << "description: " << description << ", float32 val: " << val_float32.convertToFloat() << std::endl;
51+
if (debug) {
52+
std::cout << "description: " << description << ", float32 val: " << val_float32.convertToFloat() << std::endl;
53+
}
4154

4255
APFloat::Semantics Sem = APFloat::S_Float8E8M0FNU;
4356
const llvm::fltSemantics &S = APFloat::EnumToSemantics(Sem);
4457

4558
uint64_t raw_bits_fp32 = val_float32.bitcastToAPInt().getZExtValue();
46-
std::cout << " fp32_org bits: " << std::bitset<32>(raw_bits_fp32) << "\n";
59+
if (debug) {
60+
std::cout << " fp32_org bits: " << std::bitset<32>(raw_bits_fp32) << "\n";
61+
}
4762

4863
bool losesInfo;
4964
val_float32.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
5065

51-
// print the raw bits
52-
uint64_t raw_bits_e8m0 = val_float32.bitcastToAPInt().getZExtValue();
53-
std::cout << " e8m0 bits: " << std::bitset<8>(raw_bits_e8m0) << "\n";
66+
if (debug) {
67+
// print the raw bits
68+
uint64_t raw_bits_e8m0 = val_float32.bitcastToAPInt().getZExtValue();
69+
std::cout << " e8m0 bits: " << std::bitset<8>(raw_bits_e8m0) << "\n";
70+
}
5471

5572
float val_float32_e8m0_float32 = val_float32.convertToFloat();
5673

57-
// access bits of a float
58-
union {
59-
float f;
60-
uint32_t i;
61-
} u;
62-
u.f = val_float32_e8m0_float32;
63-
// uint64_t raw_bits_e8m0_fp32 = val_float32_e8m0_float32.bitcastToAPInt().getZExtValue();
64-
std::cout << " fp32_new bits: " << std::bitset<32>(u.i) << "\n";
74+
if (debug) {
75+
// access bits of a float
76+
union float_or_uint32_t u;
77+
u.f = val_float32_e8m0_float32;
78+
std::cout << " fp32_new bits: " << std::bitset<32>(u.i) << "\n";
6579

66-
std::cout << " e8m0 -> float32 cast result: " << val_float32_e8m0_float32 << std::endl;
67-
std::cout << std::endl;
80+
std::cout << " e8m0 -> float32 cast result: " << val_float32_e8m0_float32 << std::endl;
81+
std::cout << std::endl;
82+
}
83+
84+
return val_float32_e8m0_float32;
85+
}
6886

87+
void fp32_e8m0_fp32_check_all_grs(const uint8_t exponent) {
88+
union float_or_uint32_t u;
89+
// iterate through the 8 possible combinations of GRS
90+
for (uint8_t grs = 0; grs < 8; grs++) {
91+
// create a float32 number via bit manipulation
92+
u.i = (0b0 << 31) | (exponent << 23) | (grs << (23 - 3));
93+
// TODO NaNs
94+
bool expect_round_up;
95+
uint8_t g = grs >> 2;
96+
uint8_t r = (grs & 0b010) >> 1;
97+
uint8_t s = grs & 0b1;
98+
99+
if (g == 0) {
100+
expect_round_up = false;
101+
} else {
102+
if ((r == 1) | (s == 1)) {
103+
expect_round_up = true;
104+
} else {
105+
if (exponent > 0) {
106+
// normal, round up
107+
expect_round_up = true;
108+
} else {
109+
// denormal, round down
110+
expect_round_up = false;
111+
}
112+
}
113+
}
114+
115+
const auto description = expect_round_up ? "round up" : "truncate";
116+
const auto float32_e8m0_float32 = float32_val_to_e8m0(u.f, description);
117+
float expected_value = expect_round_up ? pow(2, exponent + 1 - 127) : pow(2, exponent - 127);
118+
if (float32_e8m0_float32 != expected_value) {
119+
std::cout << "MISMATCH: expected " << expected_value << ", got " << float32_e8m0_float32 << std::endl;
120+
std::cout << "exponent " << static_cast<int>(exponent) <<
121+
" grs " << static_cast<int>(grs) <<
122+
" res " << float32_e8m0_float32 <<
123+
" pow2 " << pow(2, exponent - 127) <<
124+
" pow2+1 " << pow(2, exponent + 1 - 127) <<
125+
" round_up " << expect_round_up << std::endl << std::endl;
126+
}
127+
}
69128
}
70129

71130
int main() {
72131
std::cout << "start e8m0 test" << std::endl;
73132
std::cout << std::endl;
74133

75134
// test directly instantiating e8m0 from an integer, and the subsequent cast to float32
76-
std::cout << "=== int -> e8m0 -> float32 ===" << std::endl << std::endl;
135+
std::cout << "=== test int -> e8m0 -> float32 ===" << std::endl << std::endl;
77136

78137
int_exponent_to_e8m0(-127, "min_representable");
79138
int_exponent_to_e8m0(-2, "neg_two");
@@ -90,35 +149,73 @@ int main() {
90149

91150
// test casting float32 to e8m0 and then back to float32
92151
// the cast back is just for convenience of interpretation
93-
std::cout << "=== float32 -> e8m0 -> float32 ===" << std::endl << std::endl;
152+
std::cout << "=== test float32 -> e8m0 -> float32 ===" << std::endl << std::endl;
153+
std::cout << "===== manual test cases =====" << std::endl << std::endl;
94154

95155
// min representable value with e8m0 exponent
96156
// 2 ** -127 = 5.877472e-39
97-
float32_val_to_e8m0(5.877472e-39, "2**-127");
157+
float32_val_to_e8m0(5.877472e-39, "2**-127", true);
98158

99159
// max denormal value
100-
float32_val_to_e8m0(1.1754942e-38, "max_denormal");
160+
float32_val_to_e8m0(1.1754942e-38, "max_denormal", true);
101161

102162
// min normal
103-
float32_val_to_e8m0(1.17549435e-38, "min_normal");
163+
float32_val_to_e8m0(1.17549435e-38, "min_normal", true);
104164

105165
// basic cases
106-
float32_val_to_e8m0(0.25, "0.25");
107-
float32_val_to_e8m0(0.5, "0.5");
108-
float32_val_to_e8m0(1.0, "1.0");
109-
float32_val_to_e8m0(2.0, "2.0");
110-
float32_val_to_e8m0(4.0, "4.0");
166+
float32_val_to_e8m0(0.25, "0.25", true);
167+
float32_val_to_e8m0(0.5, "0.5", true);
168+
float32_val_to_e8m0(1.0, "1.0", true);
169+
float32_val_to_e8m0(2.0, "2.0", true);
170+
float32_val_to_e8m0(4.0, "4.0", true);
111171

112172
// max normal
113-
float32_val_to_e8m0(3.4028235e38, "max_normal");
173+
float32_val_to_e8m0(3.4028235e38, "max_normal", true);
114174

115175
// test rounding (asking for RNE in the test case)
116176
// this seems to always round up to the larger power of two at the midpoint
117-
float32_val_to_e8m0(6.0, "");
118-
float32_val_to_e8m0(3.0, "");
119-
float32_val_to_e8m0(1.5, "");
120-
float32_val_to_e8m0(0.75, "");
121-
float32_val_to_e8m0(0.375, "");
177+
float32_val_to_e8m0(6.0, "", true);
178+
float32_val_to_e8m0(3.0, "", true);
179+
float32_val_to_e8m0(1.5, "", true);
180+
float32_val_to_e8m0(0.75, "", true);
181+
float32_val_to_e8m0(0.375, "", true);
182+
183+
std::cout << "===== sweep =====" << std::endl << std::endl;
184+
185+
// rules of RNE for general floating point rounding:
186+
// LSB - last bit we'll keep
187+
// GRS - guard, round, and sticky bits
188+
//
189+
// if G == 0:
190+
// round down (truncate)
191+
// else: // G == 1
192+
// if (R == 1) or (S == 1):
193+
// round up
194+
// else:
195+
// if LSB == 1:
196+
// round up (to make LSB even)
197+
// else:
198+
// round down
199+
//
200+
// for e8m0, LSB is the implied mantissa bit, so it equals 0 for denormals and 1 for normals
201+
//
202+
// if G == 0:
203+
// round down (truncate)
204+
// else: // G == 1
205+
// if (R == 1) or (S == 1):
206+
// round up
207+
// else:
208+
// if normal_number:
209+
// round up (to next power of two)
210+
// else: // denormal number
211+
// round down (to zero)
212+
213+
// Now, let's test if the LLVM e8m0 semantics match the logic above
214+
// for (uint8_t exponent = 0; exponent < 256; exponent++) {
215+
for (uint16_t exponent_large = 0; exponent_large <= 255; exponent_large++) {
216+
uint8_t exponent_small = exponent_large;
217+
fp32_e8m0_fp32_check_all_grs(exponent_small);
218+
}
122219

123220
std::cout << "end e8m0 test" << std::endl;
124221
return 0;

llvm_e8m0/results.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
start e8m0 test
22

3-
=== int -> e8m0 -> float32 ===
3+
=== test int -> e8m0 -> float32 ===
44

55
description: min_representable, unbiased exponent: -127
66
biased exponent: 0
@@ -51,7 +51,9 @@ description: nan, unbiased exponent: 128
5151
expected cast result: 3.40282e+38
5252

5353

54-
=== float32 -> e8m0 -> float32 ===
54+
=== test float32 -> e8m0 -> float32 ===
55+
56+
===== manual test cases =====
5557

5658
description: 2**-127, float32 val: 5.87747e-39
5759
fp32_org bits: 00000000010000000000000000000000
@@ -137,4 +139,9 @@ description: , float32 val: 0.375
137139
fp32_new bits: 00111111000000000000000000000000
138140
e8m0 -> float32 cast result: 0.5
139141

142+
===== sweep =====
143+
144+
MISMATCH: expected 1.17549e-38, got 5.87747e-39
145+
exponent 0 grs 5 res 5.87747e-39 pow2 5.87747e-39 pow2+1 1.17549e-38 round_up 1
146+
140147
end e8m0 test

0 commit comments

Comments
 (0)