Skip to content

Commit cddd0db

Browse files
malfetpytorchmergebot
authored andcommitted
Add finfo properties for float8 dtypes (#109744)
Add float8 finfo checks to `test_type_info.py` Fixes #109737 Pull Request resolved: #109744 Approved by: https://github.com/drisspg
1 parent e2e9d15 commit cddd0db

File tree

3 files changed

+79
-11
lines changed

3 files changed

+79
-11
lines changed

aten/src/ATen/Dispatch.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,22 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
371371
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
372372
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
373373

374+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
375+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
376+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
377+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
378+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
379+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
380+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
381+
382+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
383+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
384+
AT_DISPATCH_SWITCH( \
385+
TYPE, \
386+
NAME, \
387+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
388+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
389+
374390
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
375391
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
376392
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \

test/test_type_info.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,26 @@ def test_finfo(self):
6868
with set_default_dtype(x.dtype):
6969
self.assertEqual(torch.finfo(x.dtype), torch.finfo())
7070

71+
# Special test case for Float8_E5M2
72+
xinfo = torch.finfo(torch.float8_e5m2)
73+
self.assertEqual(xinfo.bits, 8)
74+
self.assertEqual(xinfo.max, 57344.0)
75+
self.assertEqual(xinfo.min, -57344.0)
76+
self.assertEqual(xinfo.eps, .25)
77+
self.assertEqual(xinfo.tiny, 6.10352e-05)
78+
self.assertEqual(xinfo.resolution, 1.0)
79+
self.assertEqual(xinfo.dtype, "float8_e5m2")
80+
81+
# Special test case for Float8_E4M3FN
82+
xinfo = torch.finfo(torch.float8_e4m3fn)
83+
self.assertEqual(xinfo.bits, 8)
84+
self.assertEqual(xinfo.max, 448.0)
85+
self.assertEqual(xinfo.min, -448.0)
86+
self.assertEqual(xinfo.eps, .125)
87+
self.assertEqual(xinfo.tiny, 0.015625)
88+
self.assertEqual(xinfo.resolution, 1.0)
89+
self.assertEqual(xinfo.dtype, "float8_e4m3fn")
90+
7191
if __name__ == '__main__':
7292
TestCase._default_dtype_check_enabled = True
7393
run_tests()

torch/csrc/TypeInfo.cpp

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,43 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
112112
}
113113

114114
static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
115-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
116-
at::kHalf, at::ScalarType::BFloat16, self->type, "epsilon", [] {
115+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
116+
at::kHalf,
117+
at::ScalarType::BFloat16,
118+
at::ScalarType::Float8_e4m3fn,
119+
at::ScalarType::Float8_e5m2,
120+
self->type,
121+
"epsilon",
122+
[] {
117123
return PyFloat_FromDouble(
118124
std::numeric_limits<
119125
at::scalar_value_type<scalar_t>::type>::epsilon());
120126
});
121127
}
122128

123129
static PyObject* THPFInfo_max(THPFInfo* self, void*) {
124-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
125-
at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] {
130+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
131+
at::kHalf,
132+
at::ScalarType::BFloat16,
133+
at::ScalarType::Float8_e4m3fn,
134+
at::ScalarType::Float8_e5m2,
135+
self->type,
136+
"max",
137+
[] {
126138
return PyFloat_FromDouble(
127139
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
128140
});
129141
}
130142

131143
static PyObject* THPFInfo_min(THPFInfo* self, void*) {
132-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
133-
at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] {
144+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
145+
at::kHalf,
146+
at::ScalarType::BFloat16,
147+
at::ScalarType::Float8_e4m3fn,
148+
at::ScalarType::Float8_e5m2,
149+
self->type,
150+
"lowest",
151+
[] {
134152
return PyFloat_FromDouble(
135153
std::numeric_limits<
136154
at::scalar_value_type<scalar_t>::type>::lowest());
@@ -169,8 +187,14 @@ static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
169187
}
170188

171189
static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
172-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
173-
at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] {
190+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
191+
at::kHalf,
192+
at::ScalarType::BFloat16,
193+
at::ScalarType::Float8_e4m3fn,
194+
at::ScalarType::Float8_e5m2,
195+
self->type,
196+
"smallest",
197+
[] {
174198
return PyFloat_FromDouble(
175199
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
176200
});
@@ -182,8 +206,14 @@ static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
182206
}
183207

184208
static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
185-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
186-
at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] {
209+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
210+
at::kHalf,
211+
at::ScalarType::BFloat16,
212+
at::ScalarType::Float8_e4m3fn,
213+
at::ScalarType::Float8_e5m2,
214+
self->type,
215+
"digits10",
216+
[] {
187217
return PyFloat_FromDouble(std::pow(
188218
10,
189219
-std::numeric_limits<
@@ -193,9 +223,11 @@ static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
193223

194224
static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
195225
auto primary_name = torch::utils::getDtypeNames(self->type).first;
196-
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
226+
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
197227
at::kHalf,
198228
at::ScalarType::BFloat16,
229+
at::ScalarType::Float8_e4m3fn,
230+
at::ScalarType::Float8_e5m2,
199231
self->type,
200232
"dtype",
201233
[&primary_name] { return PyUnicode_FromString(primary_name.data()); });

0 commit comments

Comments
 (0)