Skip to content

Commit 539a971

Browse files
authored
[Release-2.1]Add finfo properties for float8 dtypes (#109808)
Add float8 finfo checks to `test_type_info.py` Fixes #109737 Cherry-pick of #109744 into release/2.1 branch Approved by: https://github.com/drisspg (cherry picked from commit cddd0db)
1 parent 9287a0c commit 539a971

File tree

3 files changed

+79
-11
lines changed

3 files changed

+79
-11
lines changed

aten/src/ATen/Dispatch.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,22 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
371371
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
372372
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
373373

374+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
375+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
376+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
377+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
378+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
379+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
380+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
381+
382+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
383+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
384+
AT_DISPATCH_SWITCH( \
385+
TYPE, \
386+
NAME, \
387+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
388+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
389+
374390
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
375391
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
376392
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \

test/test_type_info.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,25 @@ def test_finfo(self):
7272
# Restore the default type to ensure that the test has no side effect
7373
torch.set_default_dtype(initial_default_type)
7474

75+
# Special test case for Float8_E5M2
76+
xinfo = torch.finfo(torch.float8_e5m2)
77+
self.assertEqual(xinfo.bits, 8)
78+
self.assertEqual(xinfo.max, 57344.0)
79+
self.assertEqual(xinfo.min, -57344.0)
80+
self.assertEqual(xinfo.eps, .25)
81+
self.assertEqual(xinfo.tiny, 6.10352e-05)
82+
self.assertEqual(xinfo.resolution, 1.0)
83+
self.assertEqual(xinfo.dtype, "float8_e5m2")
84+
85+
# Special test case for Float8_E4M3FN
86+
xinfo = torch.finfo(torch.float8_e4m3fn)
87+
self.assertEqual(xinfo.bits, 8)
88+
self.assertEqual(xinfo.max, 448.0)
89+
self.assertEqual(xinfo.min, -448.0)
90+
self.assertEqual(xinfo.eps, .125)
91+
self.assertEqual(xinfo.tiny, 0.015625)
92+
self.assertEqual(xinfo.resolution, 1.0)
93+
self.assertEqual(xinfo.dtype, "float8_e4m3fn")
94+
7595
if __name__ == '__main__':
7696
run_tests()

torch/csrc/TypeInfo.cpp

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,25 +113,43 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
113113
}
114114

115115
static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
116-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
117-
at::kHalf, at::ScalarType::BFloat16, self->type, "epsilon", [] {
116+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
117+
at::kHalf,
118+
at::ScalarType::BFloat16,
119+
at::ScalarType::Float8_e4m3fn,
120+
at::ScalarType::Float8_e5m2,
121+
self->type,
122+
"epsilon",
123+
[] {
118124
return PyFloat_FromDouble(
119125
std::numeric_limits<
120126
at::scalar_value_type<scalar_t>::type>::epsilon());
121127
});
122128
}
123129

124130
static PyObject* THPFInfo_max(THPFInfo* self, void*) {
125-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
126-
at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] {
131+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
132+
at::kHalf,
133+
at::ScalarType::BFloat16,
134+
at::ScalarType::Float8_e4m3fn,
135+
at::ScalarType::Float8_e5m2,
136+
self->type,
137+
"max",
138+
[] {
127139
return PyFloat_FromDouble(
128140
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
129141
});
130142
}
131143

132144
static PyObject* THPFInfo_min(THPFInfo* self, void*) {
133-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
134-
at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] {
145+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
146+
at::kHalf,
147+
at::ScalarType::BFloat16,
148+
at::ScalarType::Float8_e4m3fn,
149+
at::ScalarType::Float8_e5m2,
150+
self->type,
151+
"lowest",
152+
[] {
135153
return PyFloat_FromDouble(
136154
std::numeric_limits<
137155
at::scalar_value_type<scalar_t>::type>::lowest());
@@ -170,8 +188,14 @@ static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
170188
}
171189

172190
static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
173-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
174-
at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] {
191+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
192+
at::kHalf,
193+
at::ScalarType::BFloat16,
194+
at::ScalarType::Float8_e4m3fn,
195+
at::ScalarType::Float8_e5m2,
196+
self->type,
197+
"smallest",
198+
[] {
175199
return PyFloat_FromDouble(
176200
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
177201
});
@@ -183,8 +207,14 @@ static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
183207
}
184208

185209
static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
186-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
187-
at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] {
210+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
211+
at::kHalf,
212+
at::ScalarType::BFloat16,
213+
at::ScalarType::Float8_e4m3fn,
214+
at::ScalarType::Float8_e5m2,
215+
self->type,
216+
"digits10",
217+
[] {
188218
return PyFloat_FromDouble(std::pow(
189219
10,
190220
-std::numeric_limits<
@@ -194,9 +224,11 @@ static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
194224

195225
static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
196226
auto primary_name = torch::utils::getDtypeNames(self->type).first;
197-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
227+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
198228
at::kHalf,
199229
at::ScalarType::BFloat16,
230+
at::ScalarType::Float8_e4m3fn,
231+
at::ScalarType::Float8_e5m2,
200232
self->type,
201233
"dtype",
202234
[&primary_name] { return PyUnicode_FromString(primary_name.data()); });

0 commit comments

Comments
 (0)