This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b30e1ae

Merge pull request #17 from pytorch-labs/nits

nit fixes

2 parents ade56e7 + f5f615b

1 file changed: +2, -10 lines changed

float8_playground/float8_aten_api.py

Lines changed: 2 additions & 10 deletions
@@ -17,7 +17,6 @@ def mm_float8(
     s1, # input 1 scale
     m2, # input 2 data
     s2, # input 2 scale
-    # s3, # output scale
     amax3, # amax buffer of output, updated inplace in this function
     dtype3, # output dtype
 ):
@@ -32,10 +31,7 @@ def mm_float8(
     s3 = amax_to_scale(amax3, dtype3)
 
     m3_fp32_scaled = m3_fp32 * s3
-    if dtype3 == torch.float8_e4m3fn:
-        return m3_fp32_scaled.to(torch.float8_e4m3fn)
-    else:
-        return m3_fp32_scaled.to(torch.float8_e5m2)
+    return m3_fp32_scaled.to(dtype3)
 
 # TODO naming of these vars is weird
 def addmm_float8(
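
For context, here is a minimal, self-contained sketch of the pattern this commit simplifies. It is not the repository's actual code: the full mm_float8 signature (in particular the m1 parameter), the body of amax_to_scale, and the in-place amax update are assumptions reconstructed only from the context lines visible in the diff above.

import torch

def amax_to_scale(amax, float8_dtype):
    # Hypothetical helper: derive a scale so that amax maps onto the
    # largest representable value of the target float8 dtype.
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)

def mm_float8(m1, s1, m2, s2, amax3, dtype3):
    # Assumed structure: dequantize both float8 inputs, matmul in fp32,
    # record the output amax in place, then requantize to dtype3.
    m3_fp32 = torch.mm(m1.to(torch.float32) / s1, m2.to(torch.float32) / s2)
    amax3.copy_(m3_fp32.abs().max())  # amax buffer updated inplace
    s3 = amax_to_scale(amax3, dtype3)
    m3_fp32_scaled = m3_fp32 * s3
    # After this commit, the cast no longer branches on dtype3:
    return m3_fp32_scaled.to(dtype3)

The same simplification is applied to addmm_float8 below.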
@@ -45,7 +41,6 @@ def addmm_float8(
     s1, # input 1 scale
     m2, # input 2 data
     s2, # input 2 scale
-    # s3, # output scale
     amax3, # amax buffer of output, updated inplace in this function
     dtype3, # output dtype
 ):
@@ -61,10 +56,7 @@ def addmm_float8(
     s3 = amax_to_scale(amax3, dtype3)
 
     m3_fp32_scaled = m3_fp32 * s3
-    if dtype3 == torch.float8_e4m3fn:
-        return m3_fp32_scaled.to(torch.float8_e4m3fn)
-    else:
-        return m3_fp32_scaled.to(torch.float8_e5m2)
+    return m3_fp32_scaled.to(dtype3)
 
 
 #
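
A quick sanity check (illustrative only, not from the repository) shows why the branch was redundant: Tensor.to(dtype) with the requested float8 dtype produces exactly what each arm of the old if/else produced.

import torch

x = torch.randn(8, 8)
for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    # Old behavior: branch explicitly on the output dtype.
    old = x.to(torch.float8_e4m3fn) if dt == torch.float8_e4m3fn else x.to(torch.float8_e5m2)
    # New behavior: cast directly to the requested dtype.
    new = x.to(dt)
    assert torch.equal(old.to(torch.float32), new.to(torch.float32))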

0 commit comments