@@ -9,7 +9,10 @@
 
 import torch
 
-import float8_aten_api
+from float8_python_api import (
+    mm_float8,
+    addmm_float8,
+)
 
 from float8_utils import (
     tensor_to_amax,
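Note: the new float8_python_api wrappers themselves are not part of this diff. Judging only from the call sites below, they appear to accept Float8Tensor arguments directly and unpack the raw bits and scales before dispatching to the existing aten ops. A minimal sketch of what such a wrapper could look like, with signatures inferred from the call sites rather than taken from the actual file:

    # Hypothetical sketch of float8_python_api, inferred from the call sites
    # in this commit; the real wrappers may differ.
    import torch

    def mm_float8(a, b, output_amax, output_scale, output_dtype):
        # a and b are Float8Tensor-like objects carrying raw fp8 bits in
        # _data and a dequantization scale in _scale; the aten kernel still
        # consumes the unpacked (data, scale) pairs.
        return torch.ops.aten.mm_float8(
            a._data, a._scale,
            b._data, b._scale,
            output_amax, output_scale, output_dtype)

    def addmm_float8(bias, a, b, output_amax, output_scale, output_dtype):
        # Same idea for the fused-bias variant.
        return torch.ops.aten.addmm_float8(
            bias._data, bias._scale,
            a._data, a._scale,
            b._data, b._scale,
            output_amax, output_scale, output_dtype)

Keeping the _data/_scale plumbing inside the wrappers is what lets the autograd code below pass Float8Tensor objects around directly.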
@@ -52,11 +55,9 @@ def forward(
                 float8_amax_out.fill_(tensor_to_amax(ref_result))
 
             y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
-            res_bits = torch.ops.aten.addmm_float8(
-                b_fp8._data, b_fp8._scale,
-                x_fp8_reshaped._data, x_fp8._scale,
-                w_fp8.t()._data, w_fp8._scale,
-                float8_amax_out, y_scale, torch.float8_e4m3fn)
+            res_bits = addmm_float8(
+                b_fp8, x_fp8_reshaped, w_fp8.t(), float8_amax_out, y_scale,
+                torch.float8_e4m3fn)
         else:
             if not is_fw_amax_initialized:
                 # calculate reference amax of output
@@ -65,10 +66,9 @@ def forward(
                 float8_amax_out.fill_(tensor_to_amax(ref_result))
 
             y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
-            res_bits = torch.ops.aten.mm_float8(
-                x_fp8_reshaped._data, x_fp8._scale,
-                w_fp8.t()._data, w_fp8._scale,
-                float8_amax_out, y_scale, torch.float8_e4m3fn)
+            res_bits = mm_float8(
+                x_fp8_reshaped, w_fp8.t(), float8_amax_out, y_scale,
+                torch.float8_e4m3fn)
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
 
         res = Float8Tensor(res_bits, y_scale, x_fp8._orig_dtype)
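For context on the forward branches above: the output amax is tracked, converted to a scale, the fp8 matmul produces raw result bits, and those bits are reshaped and rewrapped as a Float8Tensor. A shape-level sketch of the no-bias path in plain matmul terms (placeholder names and values; the fp8 quantization details are deliberately omitted):

    import torch
    x = torch.randn(4, 3, 16)                      # (..., in_features), placeholder data
    w = torch.randn(8, 16)                         # (out_features, in_features)
    x_reshaped = x.reshape(-1, x.shape[-1])        # collapse leading dims
    y = torch.mm(x_reshaped, w.t())                # cf. mm_float8(x_fp8_reshaped, w_fp8.t(), ...)
    y = y.reshape(*x.shape[:-1], y.shape[-1])      # restore leading dims, cf. res_bits.reshape(...)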
@@ -105,10 +105,8 @@ def backward(ctx, go):
             float8_amax_dL_dX.fill_(tensor_to_amax(dL_dX_ref))
 
         dL_dX_scale = amax_to_scale(float8_amax_dL_dX, torch.float8_e5m2)
-        dL_dX_bits = torch.ops.aten.mm_float8(
-            go_fp8_reshaped._data, go_fp8._scale,
-            w_fp8._data, w_fp8._scale,
-            float8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
+        dL_dX_bits = mm_float8(
+            go_fp8_reshaped, w_fp8, float8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
         dL_dX_bits = dL_dX_bits.reshape(*go_fp8_orig_shape[:-1], dL_dX_bits.shape[-1])
         dL_dX_fp8 = Float8Tensor(dL_dX_bits, dL_dX_scale, go_fp8._orig_dtype)
 
@@ -122,10 +120,9 @@ def backward(ctx, go):
             float8_amax_dL_dW.fill_(tensor_to_amax(dL_dW_ref))
 
         dL_dW_scale = amax_to_scale(float8_amax_dL_dW, torch.float8_e5m2)
-        dL_dW_bits = torch.ops.aten.mm_float8(
-            x_fp8_reshaped.t()._data, x_fp8._scale,
-            go_fp8_reshaped._data, go_fp8._scale,
-            float8_amax_dL_dW, dL_dW_scale, torch.float8_e5m2).t()
+        dL_dW_bits = mm_float8(
+            x_fp8_reshaped.t(), go_fp8_reshaped, float8_amax_dL_dW,
+            dL_dW_scale, torch.float8_e5m2).t()
         dL_dW_fp8 = Float8Tensor(dL_dW_bits, dL_dW_scale, go_fp8._orig_dtype)
 
         if not is_bw_amax_initialized:
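The two backward calls compute the standard linear-layer gradients: with go the incoming gradient of the output, dL_dX = go @ w and dL_dW = (x.t() @ go).t(); e5m2 is used here because gradients typically benefit from its wider exponent range. A shape-level sketch with the fp8 packing omitted and names chosen for illustration:

    import torch
    N, in_features, out_features = 6, 16, 8
    x = torch.randn(N, in_features)
    w = torch.randn(out_features, in_features)
    go = torch.randn(N, out_features)    # incoming gradient of the output
    dL_dX = torch.mm(go, w)              # cf. mm_float8(go_fp8_reshaped, w_fp8, ...)
    dL_dW = torch.mm(x.t(), go).t()      # cf. mm_float8(x_fp8_reshaped.t(), go_fp8_reshaped, ...).t()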