@@ -12,6 +12,12 @@ using namespace llvm;
12
12
// PR adding this functionality, for context: https://github.com/llvm/llvm-project/pull/107127
13
13
// this isn't in released LLVM versions, so need to build from source to test it out
14
14
15
// Overlay of a float and its raw 32-bit pattern, used to build or inspect a
// float32 value bit by bit: write one member, read the other.
// NOTE(review): reading the member that was not written last is formally
// undefined in ISO C++; this relies on compilers that support union type
// punning as an extension (GCC and Clang document it) - confirm for others.
union float_or_uint32_t {
    float f;     // the storage viewed as a float32 value
    uint32_t i;  // the same storage viewed as raw bits
};

// The overlay is only meaningful if both views cover exactly the same bytes.
static_assert(sizeof(float) == sizeof(uint32_t),
              "float_or_uint32_t requires a 32-bit float");
20
+
15
21
16
22
void int_exponent_to_e8m0 (int unbiased_exponent, const std::string& description) {
17
23
std::cout << " description: " << description << " , unbiased exponent: " << unbiased_exponent << std::endl;
@@ -35,45 +41,98 @@ void int_exponent_to_e8m0(int unbiased_exponent, const std::string& description)
35
41
std::cout << std::endl;
36
42
}
37
43
38
- void float32_val_to_e8m0 (float val, const std::string& description) {
44
+ float float32_val_to_e8m0 (
45
+ float val,
46
+ const std::string& description,
47
+ bool debug=false
48
+ ) {
49
+
39
50
APFloat val_float32 (val); // 32-bit float
40
- std::cout << " description: " << description << " , float32 val: " << val_float32.convertToFloat () << std::endl;
51
+ if (debug) {
52
+ std::cout << " description: " << description << " , float32 val: " << val_float32.convertToFloat () << std::endl;
53
+ }
41
54
42
55
APFloat::Semantics Sem = APFloat::S_Float8E8M0FNU;
43
56
const llvm::fltSemantics &S = APFloat::EnumToSemantics (Sem);
44
57
45
58
uint64_t raw_bits_fp32 = val_float32.bitcastToAPInt ().getZExtValue ();
46
- std::cout << " fp32_org bits: " << std::bitset<32 >(raw_bits_fp32) << " \n " ;
59
+ if (debug) {
60
+ std::cout << " fp32_org bits: " << std::bitset<32 >(raw_bits_fp32) << " \n " ;
61
+ }
47
62
48
63
bool losesInfo;
49
64
val_float32.convert (S, APFloat::rmNearestTiesToEven, &losesInfo);
50
65
51
- // print the raw bits
52
- uint64_t raw_bits_e8m0 = val_float32.bitcastToAPInt ().getZExtValue ();
53
- std::cout << " e8m0 bits: " << std::bitset<8 >(raw_bits_e8m0) << " \n " ;
66
+ if (debug) {
67
+ // print the raw bits
68
+ uint64_t raw_bits_e8m0 = val_float32.bitcastToAPInt ().getZExtValue ();
69
+ std::cout << " e8m0 bits: " << std::bitset<8 >(raw_bits_e8m0) << " \n " ;
70
+ }
54
71
55
72
float val_float32_e8m0_float32 = val_float32.convertToFloat ();
56
73
57
- // access bits of a float
58
- union {
59
- float f;
60
- uint32_t i;
61
- } u;
62
- u.f = val_float32_e8m0_float32;
63
- // uint64_t raw_bits_e8m0_fp32 = val_float32_e8m0_float32.bitcastToAPInt().getZExtValue();
64
- std::cout << " fp32_new bits: " << std::bitset<32 >(u.i ) << " \n " ;
74
+ if (debug) {
75
+ // access bits of a float
76
+ union float_or_uint32_t u;
77
+ u.f = val_float32_e8m0_float32;
78
+ std::cout << " fp32_new bits: " << std::bitset<32 >(u.i ) << " \n " ;
65
79
66
- std::cout << " e8m0 -> float32 cast result: " << val_float32_e8m0_float32 << std::endl;
67
- std::cout << std::endl;
80
+ std::cout << " e8m0 -> float32 cast result: " << val_float32_e8m0_float32 << std::endl;
81
+ std::cout << std::endl;
82
+ }
83
+
84
+ return val_float32_e8m0_float32;
85
+ }
68
86
87
+ void fp32_e8m0_fp32_check_all_grs (const uint8_t exponent) {
88
+ union float_or_uint32_t u;
89
+ // iterate through the 8 possible combinations of GRS
90
+ for (uint8_t grs = 0 ; grs < 8 ; grs++) {
91
+ // create a float32 number via bit manipulation
92
+ u.i = (0b0 << 31 ) | (exponent << 23 ) | (grs << (23 - 3 ));
93
+ // TODO NaNs
94
+ bool expect_round_up;
95
+ uint8_t g = grs >> 2 ;
96
+ uint8_t r = (grs & 0b010 ) >> 1 ;
97
+ uint8_t s = grs & 0b1 ;
98
+
99
+ if (g == 0 ) {
100
+ expect_round_up = false ;
101
+ } else {
102
+ if ((r == 1 ) | (s == 1 )) {
103
+ expect_round_up = true ;
104
+ } else {
105
+ if (exponent > 0 ) {
106
+ // normal, round up
107
+ expect_round_up = true ;
108
+ } else {
109
+ // denormal, round down
110
+ expect_round_up = false ;
111
+ }
112
+ }
113
+ }
114
+
115
+ const auto description = expect_round_up ? " round up" : " truncate" ;
116
+ const auto float32_e8m0_float32 = float32_val_to_e8m0 (u.f , description);
117
+ float expected_value = expect_round_up ? pow (2 , exponent + 1 - 127 ) : pow (2 , exponent - 127 );
118
+ if (float32_e8m0_float32 != expected_value) {
119
+ std::cout << " MISMATCH: expected " << expected_value << " , got " << float32_e8m0_float32 << std::endl;
120
+ std::cout << " exponent " << static_cast <int >(exponent) <<
121
+ " grs " << static_cast <int >(grs) <<
122
+ " res " << float32_e8m0_float32 <<
123
+ " pow2 " << pow (2 , exponent - 127 ) <<
124
+ " pow2+1 " << pow (2 , exponent + 1 - 127 ) <<
125
+ " round_up " << expect_round_up << std::endl << std::endl;
126
+ }
127
+ }
69
128
}
70
129
71
130
int main () {
72
131
std::cout << " start e8m0 test" << std::endl;
73
132
std::cout << std::endl;
74
133
75
134
// test directly instantiating e8m0 from an integer, and the subsequent cast to float32
76
- std::cout << " === int -> e8m0 -> float32 ===" << std::endl << std::endl;
135
+ std::cout << " === test int -> e8m0 -> float32 ===" << std::endl << std::endl;
77
136
78
137
int_exponent_to_e8m0 (-127 , " min_representable" );
79
138
int_exponent_to_e8m0 (-2 , " neg_two" );
@@ -90,35 +149,73 @@ int main() {
90
149
91
150
// test casting float32 to e8m0 and then back to float32
92
151
// the cast back is just for convenience of interpretation
93
- std::cout << " === float32 -> e8m0 -> float32 ===" << std::endl << std::endl;
152
+ std::cout << " === test float32 -> e8m0 -> float32 ===" << std::endl << std::endl;
153
+ std::cout << " ===== manual test cases =====" << std::endl << std::endl;
94
154
95
155
// min representable value with e8m0 exponent
96
156
// 2 ** -127 = 5.877472e-39
97
- float32_val_to_e8m0 (5.877472e-39 , " 2**-127" );
157
+ float32_val_to_e8m0 (5.877472e-39 , " 2**-127" , true );
98
158
99
159
// max denormal value
100
- float32_val_to_e8m0 (1.1754942e-38 , " max_denormal" );
160
+ float32_val_to_e8m0 (1.1754942e-38 , " max_denormal" , true );
101
161
102
162
// min normal
103
- float32_val_to_e8m0 (1.17549435e-38 , " min_normal" );
163
+ float32_val_to_e8m0 (1.17549435e-38 , " min_normal" , true );
104
164
105
165
// basic cases
106
- float32_val_to_e8m0 (0.25 , " 0.25" );
107
- float32_val_to_e8m0 (0.5 , " 0.5" );
108
- float32_val_to_e8m0 (1.0 , " 1.0" );
109
- float32_val_to_e8m0 (2.0 , " 2.0" );
110
- float32_val_to_e8m0 (4.0 , " 4.0" );
166
+ float32_val_to_e8m0 (0.25 , " 0.25" , true );
167
+ float32_val_to_e8m0 (0.5 , " 0.5" , true );
168
+ float32_val_to_e8m0 (1.0 , " 1.0" , true );
169
+ float32_val_to_e8m0 (2.0 , " 2.0" , true );
170
+ float32_val_to_e8m0 (4.0 , " 4.0" , true );
111
171
112
172
// max normal
113
- float32_val_to_e8m0 (3.4028235e38 , " max_normal" );
173
+ float32_val_to_e8m0 (3.4028235e38 , " max_normal" , true );
114
174
115
175
// test rounding (asking for RNE in the test case)
116
176
// this seems to always round up to the larger power of two at the midpoint
117
- float32_val_to_e8m0 (6.0 , " " );
118
- float32_val_to_e8m0 (3.0 , " " );
119
- float32_val_to_e8m0 (1.5 , " " );
120
- float32_val_to_e8m0 (0.75 , " " );
121
- float32_val_to_e8m0 (0.375 , " " );
177
+ float32_val_to_e8m0 (6.0 , " " , true );
178
+ float32_val_to_e8m0 (3.0 , " " , true );
179
+ float32_val_to_e8m0 (1.5 , " " , true );
180
+ float32_val_to_e8m0 (0.75 , " " , true );
181
+ float32_val_to_e8m0 (0.375 , " " , true );
182
+
183
+ std::cout << " ===== sweep =====" << std::endl << std::endl;
184
+
185
+ // rules of RNE for general floating point rounding:
186
+ // LSB - last bit we'll keep
187
+ // GRS - guard, round, and sticky bits
188
+ //
189
+ // if G == 0:
190
+ // round down (truncate)
191
+ // else: // G == 1
192
+ // if (R == 1) or (S == 1):
193
+ // round up
194
+ // else:
195
+ // if LSB == 1:
196
+ // round up (to make LSB even)
197
+ // else:
198
+ // round down
199
+ //
200
+ // for e8m0, LSB is the implied mantissa bit, so it is 0 for denormals and 1 for normals
201
+ //
202
+ // if G == 0:
203
+ // round down (truncate)
204
+ // else: // G == 1
205
+ // if (R == 1) or (S == 1):
206
+ // round up
207
+ // else:
208
+ // if normal_number:
209
+ // round up (to next power of two)
210
+ // else: // denormal number
211
+ // round down (to zero)
212
+
213
+ // Now, let's test if the LLVM e8m0 semantics match the logic above
214
+ // for (uint8_t exponent = 0; exponent < 256; exponent++) {
215
+ for (uint16_t exponent_large = 0 ; exponent_large <= 255 ; exponent_large++) {
216
+ uint8_t exponent_small = exponent_large;
217
+ fp32_e8m0_fp32_check_all_grs (exponent_small);
218
+ }
122
219
123
220
std::cout << " end e8m0 test" << std::endl;
124
221
return 0 ;
0 commit comments