This repository was archived by the owner on Aug 7, 2024. It is now read-only.
1 file changed: +2 -10 lines changed
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@ def mm_float8(
17
17
s1 , # input 1 scale
18
18
m2 , # input 2 data
19
19
s2 , # input 2 scale
20
- # s3, # ouput scale
21
20
amax3 , # amax buffer of output, updated inplace in this function
22
21
dtype3 , # output dtype
23
22
):
@@ -32,10 +31,7 @@ def mm_float8(
32
31
s3 = amax_to_scale (amax3 , dtype3 )
33
32
34
33
m3_fp32_scaled = m3_fp32 * s3
35
- if dtype3 == torch .float8_e4m3fn :
36
- return m3_fp32_scaled .to (torch .float8_e4m3fn )
37
- else :
38
- return m3_fp32_scaled .to (torch .float8_e5m2 )
34
+ return m3_fp32_scaled .to (dtype3 )
39
35
40
36
# TODO naming of these vars is weird
41
37
def addmm_float8 (
@@ -45,7 +41,6 @@ def addmm_float8(
45
41
s1 , # input 1 scale
46
42
m2 , # input 2 data
47
43
s2 , # input 2 scale
48
- # s3, # output scale
49
44
amax3 , # amax buffer of output, updated inplace in this function
50
45
dtype3 , # output dtype
51
46
):
@@ -61,10 +56,7 @@ def addmm_float8(
61
56
s3 = amax_to_scale (amax3 , dtype3 )
62
57
63
58
m3_fp32_scaled = m3_fp32 * s3
64
- if dtype3 == torch .float8_e4m3fn :
65
- return m3_fp32_scaled .to (torch .float8_e4m3fn )
66
- else :
67
- return m3_fp32_scaled .to (torch .float8_e5m2 )
59
+ return m3_fp32_scaled .to (dtype3 )
68
60
69
61
70
62
#
You can’t perform that action at this time.
0 commit comments