@@ -25,25 +25,34 @@ def forward(
         ctx,
         x_fp8,
         w_fp8,
+        b_fp8,
         fp8_s_out,
         fp8_s_dL_dX,
         fp8_s_dL_dW,
         fp8_s_dL_dY,
     ):
-        ctx.save_for_backward(x_fp8, w_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY)
-
-        res_bits = torch.ops.aten.mm_float8(
-            x_fp8._data, x_fp8._scale, x_fp8._flavor,
-            w_fp8._data.t(), w_fp8._scale, w_fp8._flavor,
-            fp8_s_out, E4M3)
+        ctx.save_for_backward(
+            x_fp8, w_fp8, b_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY)
+        if b_fp8 is not None:
+            # TODO add this
+            res_bits = torch.ops.aten.addmm_float8(
+                b_fp8._data, b_fp8._scale, b_fp8._flavor,
+                x_fp8._data, x_fp8._scale, x_fp8._flavor,
+                w_fp8._data.t(), w_fp8._scale, w_fp8._flavor,
+                fp8_s_out, E4M3)
+        else:
+            res_bits = torch.ops.aten.mm_float8(
+                x_fp8._data, x_fp8._scale, x_fp8._flavor,
+                w_fp8._data.t(), w_fp8._scale, w_fp8._flavor,
+                fp8_s_out, E4M3)
 
         res = Float8Tensor(res_bits, fp8_s_out, E4M3)
         # scale update would also happen here, for now no-op
         return res
 
     @staticmethod
     def backward(ctx, go):
-        x_fp8, w_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY = \
+        x_fp8, w_fp8, b_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY = \
             ctx.saved_tensors
 
         if not isinstance(go, Float8Tensor):
@@ -69,7 +78,10 @@ def backward(ctx, go):
         dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW, E5M2)
 
         # scale update would also happen here, for now no-op
-        return dL_dX_fp8, dL_dW_fp8, None, None, None, None
+        if b_fp8 is not None:
+            return dL_dX_fp8, dL_dW_fp8, go_fp8, None, None, None, None
+        else:
+            return dL_dX_fp8, dL_dW_fp8, None, None, None, None, None
 
 
 class Float8Linear(torch.nn.Linear):
@@ -86,6 +98,7 @@ def __init__(self, *args, **kwargs):
         # or PTQ calibration.
         self.register_buffer('fp8_s_in', torch.tensor(1.0))
         self.register_buffer('fp8_s_weight', torch.tensor(1.0))
+        self.register_buffer('fp8_s_bias', torch.tensor(1.0))
         self.register_buffer('fp8_s_out', torch.tensor(1.0))
         self.register_buffer('fp8_s_dL_dX', torch.tensor(1.0))
         self.register_buffer('fp8_s_dL_dW', torch.tensor(1.0))
@@ -102,9 +115,13 @@ def forward(self, x):
         # TODO(future): switch to delayed scaling
         self.fp8_s_weight.fill_(tensor_to_scale(self.weight, E4M3))
         w_fp8 = Float8Tensor.from_float32(self.weight, self.fp8_s_weight, E4M3)
+        maybe_b_fp8 = None
+        if self.bias is not None:
+            self.fp8_s_bias.fill_(tensor_to_scale(self.bias, E4M3))
+            maybe_b_fp8 = Float8Tensor.from_float32(self.bias, self.fp8_s_bias, E4M3)
 
         y_fp8 = float8_linear_no_bias.apply(
-            x_fp8, w_fp8, self.fp8_s_out, self.fp8_s_dL_dX,
+            x_fp8, w_fp8, maybe_b_fp8, self.fp8_s_out, self.fp8_s_dL_dX,
             self.fp8_s_dL_dW, self.fp8_s_dL_dY)
 
         # For now, hardcode returning Float8Tensor (propagate as much as we can).
@@ -116,7 +133,7 @@ def from_float(cls, mod):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
         """
-        assert mod.bias is None, 'bias support not implemented yet'
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = mod.weight
+        new_mod.bias = mod.bias
         return new_mod
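
A note on why the tuple returned by backward grows from six to seven entries: torch.autograd.Function.backward must return exactly one gradient per argument of forward (excluding ctx), with None for inputs that need no gradient, and forward now takes the extra b_fp8 argument. Below is a minimal, self-contained sketch of that contract with the same optional-bias signature; it is a toy fp32 Function written for illustration, not the prototype's fp8 code.

import torch

# Toy Function mirroring the (x, w, b, *scales) signature of the patched
# float8_linear_no_bias above; it computes in plain fp32 and exists only to
# show the one-gradient-per-forward-input contract.
class toy_linear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, b, s_out, s_dL_dX, s_dL_dW, s_dL_dY):
        ctx.save_for_backward(x, w, b)
        out = x @ w.t()
        if b is not None:
            out = out + b
        return out

    @staticmethod
    def backward(ctx, go):
        x, w, b = ctx.saved_tensors
        dX = go @ w          # gradient w.r.t. x
        dW = go.t() @ x      # gradient w.r.t. w
        dB = go.sum(0) if b is not None else None  # gradient w.r.t. b
        # Seven outputs: x, w, b, plus None for the four scale arguments.
        return dX, dW, dB, None, None, None, None

x = torch.randn(4, 16, requires_grad=True)
w = torch.randn(32, 16, requires_grad=True)
b = torch.randn(32, requires_grad=True)
s = torch.tensor(1.0)
y = toy_linear.apply(x, w, b, s, s, s, s)
y.sum().backward()
print(x.grad.shape, w.grad.shape, b.grad.shape)  # (4, 16), (32, 16), (32,)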
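A sketch, under assumptions, of how the new bias path would be exercised after this change: it assumes the prototype module defining Float8Linear is importable (the import path below is hypothetical), and the addmm_float8 branch is still marked TODO in the patch, so the bias path may not run end to end yet.

import torch
from float8_linear import Float8Linear  # hypothetical import path

# from_float no longer asserts that bias is None; the bias is copied onto the
# new module and quantized to E4M3 inside forward() when present.
m = Float8Linear.from_float(torch.nn.Linear(16, 32, bias=True))
y = m(torch.randn(4, 16))  # takes the addmm_float8 branch since bias is not None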