@@ -33,15 +33,15 @@ def forward(
33
33
x_fp8 ,
34
34
w_fp8 ,
35
35
b_fp8 ,
36
- float8_amax_out ,
37
- float8_amax_dL_dX ,
38
- float8_amax_dL_dW ,
39
- float8_amax_dL_dY ,
36
+ fp8_amax_y ,
37
+ fp8_amax_dL_dX ,
38
+ fp8_amax_dL_dW ,
39
+ fp8_amax_dL_dY ,
40
40
fw_amax_initialized ,
41
41
bw_amax_initialized ,
42
42
):
43
43
ctx .save_for_backward (
44
- x_fp8 , w_fp8 , b_fp8 , float8_amax_dL_dX , float8_amax_dL_dW , float8_amax_dL_dY ,
44
+ x_fp8 , w_fp8 , b_fp8 , fp8_amax_dL_dX , fp8_amax_dL_dW , fp8_amax_dL_dY ,
45
45
bw_amax_initialized )
46
46
orig_shape = x_fp8 ._data .shape
47
47
x_fp8_reshaped = x_fp8 .reshape (- 1 , orig_shape [- 1 ])
@@ -52,22 +52,22 @@ def forward(
52
52
# calculate reference amax of output
53
53
with torch .no_grad ():
54
54
ref_result = torch .addmm (b_fp8 , x_fp8_reshaped , w_fp8 .t ())
55
- float8_amax_out .fill_ (tensor_to_amax (ref_result ))
55
+ fp8_amax_y .fill_ (tensor_to_amax (ref_result ))
56
56
57
- y_scale = amax_to_scale (float8_amax_out , torch .float8_e4m3fn )
57
+ y_scale = amax_to_scale (fp8_amax_y , torch .float8_e4m3fn )
58
58
res_bits = addmm_float8 (
59
- b_fp8 , x_fp8_reshaped , w_fp8 .t (), float8_amax_out , y_scale ,
59
+ b_fp8 , x_fp8_reshaped , w_fp8 .t (), fp8_amax_y , y_scale ,
60
60
torch .float8_e4m3fn )
61
61
else :
62
62
if not is_fw_amax_initialized :
63
63
# calculate reference amax of output
64
64
with torch .no_grad ():
65
65
ref_result = torch .mm (x_fp8_reshaped , w_fp8 .t ())
66
- float8_amax_out .fill_ (tensor_to_amax (ref_result ))
66
+ fp8_amax_y .fill_ (tensor_to_amax (ref_result ))
67
67
68
- y_scale = amax_to_scale (float8_amax_out , torch .float8_e4m3fn )
68
+ y_scale = amax_to_scale (fp8_amax_y , torch .float8_e4m3fn )
69
69
res_bits = mm_float8 (
70
- x_fp8_reshaped , w_fp8 .t (), float8_amax_out , y_scale ,
70
+ x_fp8_reshaped , w_fp8 .t (), fp8_amax_y , y_scale ,
71
71
torch .float8_e4m3fn )
72
72
res_bits = res_bits .reshape (* orig_shape [:- 1 ], res_bits .shape [- 1 ])
73
73
@@ -77,18 +77,18 @@ def forward(
77
77
78
78
@staticmethod
79
79
def backward (ctx , go ):
80
- x_fp8 , w_fp8 , b_fp8 , float8_amax_dL_dX , float8_amax_dL_dW , \
81
- float8_amax_dL_dY , bw_amax_initialized = \
80
+ x_fp8 , w_fp8 , b_fp8 , fp8_amax_dL_dX , fp8_amax_dL_dW , \
81
+ fp8_amax_dL_dY , bw_amax_initialized = \
82
82
ctx .saved_tensors
83
83
84
84
is_bw_amax_initialized = torch .any (bw_amax_initialized )
85
85
86
86
if not isinstance (go , Float8Tensor ):
87
87
# TODO(future): switch to windowed delayed scaling
88
88
if not is_bw_amax_initialized :
89
- float8_amax_dL_dY .fill_ (tensor_to_amax (go ))
90
- dL_dY_scale = amax_to_scale (float8_amax_dL_dY , torch .float8_e5m2 )
91
- float8_amax_dL_dY .fill_ (tensor_to_amax (go ))
89
+ fp8_amax_dL_dY .fill_ (tensor_to_amax (go ))
90
+ dL_dY_scale = amax_to_scale (fp8_amax_dL_dY , torch .float8_e5m2 )
91
+ fp8_amax_dL_dY .fill_ (tensor_to_amax (go ))
92
92
go_fp8 = Float8Tensor (
93
93
(go * dL_dY_scale ).to (torch .float8_e5m2 ),
94
94
dL_dY_scale , go .dtype )
@@ -102,11 +102,11 @@ def backward(ctx, go):
102
102
# calculate reference amax of output
103
103
with torch .no_grad ():
104
104
dL_dX_ref = torch .mm (go_fp8_reshaped , w_fp8 )
105
- float8_amax_dL_dX .fill_ (tensor_to_amax (dL_dX_ref ))
105
+ fp8_amax_dL_dX .fill_ (tensor_to_amax (dL_dX_ref ))
106
106
107
- dL_dX_scale = amax_to_scale (float8_amax_dL_dX , torch .float8_e5m2 )
107
+ dL_dX_scale = amax_to_scale (fp8_amax_dL_dX , torch .float8_e5m2 )
108
108
dL_dX_bits = mm_float8 (
109
- go_fp8_reshaped , w_fp8 , float8_amax_dL_dX , dL_dX_scale , torch .float8_e5m2 )
109
+ go_fp8_reshaped , w_fp8 , fp8_amax_dL_dX , dL_dX_scale , torch .float8_e5m2 )
110
110
dL_dX_bits = dL_dX_bits .reshape (* go_fp8_orig_shape [:- 1 ], dL_dX_bits .shape [- 1 ])
111
111
dL_dX_fp8 = Float8Tensor (dL_dX_bits , dL_dX_scale , go_fp8 ._orig_dtype )
112
112
@@ -117,11 +117,11 @@ def backward(ctx, go):
117
117
# calculate reference amax of output
118
118
with torch .no_grad ():
119
119
dL_dW_ref = torch .mm (x_fp8_reshaped .t (), go_fp8_reshaped ).t ()
120
- float8_amax_dL_dW .fill_ (tensor_to_amax (dL_dW_ref ))
120
+ fp8_amax_dL_dW .fill_ (tensor_to_amax (dL_dW_ref ))
121
121
122
- dL_dW_scale = amax_to_scale (float8_amax_dL_dW , torch .float8_e5m2 )
122
+ dL_dW_scale = amax_to_scale (fp8_amax_dL_dW , torch .float8_e5m2 )
123
123
dL_dW_bits = mm_float8 (
124
- x_fp8_reshaped .t (), go_fp8_reshaped , float8_amax_dL_dW ,
124
+ x_fp8_reshaped .t (), go_fp8_reshaped , fp8_amax_dL_dW ,
125
125
dL_dW_scale , torch .float8_e5m2 ).t ()
126
126
dL_dW_fp8 = Float8Tensor (dL_dW_bits , dL_dW_scale , go_fp8 ._orig_dtype )
127
127
@@ -147,13 +147,13 @@ def __init__(self, *args, **kwargs):
147
147
# scaling such as the mechanism described in
148
148
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8,
149
149
# or PTQ calibration.
150
- self .register_buffer ('float8_amax_in ' , torch .tensor (E4M3_MAX_POS ))
151
- self .register_buffer ('float8_amax_weight ' , torch .tensor (E4M3_MAX_POS ))
152
- self .register_buffer ('float8_amax_bias ' , torch .tensor (E4M3_MAX_POS ))
153
- self .register_buffer ('float8_amax_out ' , torch .tensor (E4M3_MAX_POS ))
154
- self .register_buffer ('float8_amax_dL_dX ' , torch .tensor (E5M2_MAX_POS ))
155
- self .register_buffer ('float8_amax_dL_dW ' , torch .tensor (E5M2_MAX_POS ))
156
- self .register_buffer ('float8_amax_dL_dY ' , torch .tensor (E5M2_MAX_POS ))
150
+ self .register_buffer ('fp8_amax_x ' , torch .tensor (E4M3_MAX_POS ))
151
+ self .register_buffer ('fp8_amax_w ' , torch .tensor (E4M3_MAX_POS ))
152
+ self .register_buffer ('fp8_amax_b ' , torch .tensor (E4M3_MAX_POS ))
153
+ self .register_buffer ('fp8_amax_y ' , torch .tensor (E4M3_MAX_POS ))
154
+ self .register_buffer ('fp8_amax_dL_dX ' , torch .tensor (E5M2_MAX_POS ))
155
+ self .register_buffer ('fp8_amax_dL_dW ' , torch .tensor (E5M2_MAX_POS ))
156
+ self .register_buffer ('fp8_amax_dL_dY ' , torch .tensor (E5M2_MAX_POS ))
157
157
self .register_buffer ('fw_amax_initialized' , torch .tensor ([0 ], dtype = torch .uint8 ))
158
158
self .register_buffer ('bw_amax_initialized' , torch .tensor ([0 ], dtype = torch .uint8 ))
159
159
@@ -169,34 +169,34 @@ def forward(self, x):
169
169
170
170
# TODO(future): switch to windowed delayed scaling
171
171
if not is_fw_amax_initialized :
172
- self .float8_amax_in .fill_ (tensor_to_amax (x ))
173
- x_scale = amax_to_scale (self .float8_amax_in , torch .float8_e4m3fn )
174
- self .float8_amax_in .fill_ (tensor_to_amax (x ))
172
+ self .fp8_amax_x .fill_ (tensor_to_amax (x ))
173
+ x_scale = amax_to_scale (self .fp8_amax_x , torch .float8_e4m3fn )
174
+ self .fp8_amax_x .fill_ (tensor_to_amax (x ))
175
175
176
176
x_fp8 = Float8Tensor .to_float8 (x , x_scale , torch .float8_e4m3fn )
177
177
else :
178
178
x_fp8 = x
179
179
180
180
# TODO(future): switch to windowed delayed scaling
181
181
if not is_fw_amax_initialized :
182
- self .float8_amax_weight .fill_ (tensor_to_amax (self .weight ))
183
- w_scale = amax_to_scale (self .float8_amax_weight , torch .float8_e4m3fn )
184
- self .float8_amax_weight .fill_ (tensor_to_amax (self .weight ))
182
+ self .fp8_amax_w .fill_ (tensor_to_amax (self .weight ))
183
+ w_scale = amax_to_scale (self .fp8_amax_w , torch .float8_e4m3fn )
184
+ self .fp8_amax_w .fill_ (tensor_to_amax (self .weight ))
185
185
186
186
w_fp8 = Float8Tensor .to_float8 (self .weight , w_scale , torch .float8_e4m3fn )
187
187
maybe_b_fp8 = None
188
188
if self .bias is not None :
189
189
# TODO(future): switch to windowed delayed scaling
190
190
if not is_fw_amax_initialized :
191
- self .float8_amax_bias .fill_ (tensor_to_amax (self .bias ))
192
- b_scale = amax_to_scale (self .float8_amax_bias , torch .float8_e4m3fn )
193
- self .float8_amax_bias .fill_ (tensor_to_amax (self .bias ))
191
+ self .fp8_amax_b .fill_ (tensor_to_amax (self .bias ))
192
+ b_scale = amax_to_scale (self .fp8_amax_b , torch .float8_e4m3fn )
193
+ self .fp8_amax_b .fill_ (tensor_to_amax (self .bias ))
194
194
195
195
maybe_b_fp8 = Float8Tensor .to_float8 (self .bias , b_scale , torch .float8_e4m3fn )
196
196
197
197
y_fp8 = float8_linear .apply (
198
- x_fp8 , w_fp8 , maybe_b_fp8 , self .float8_amax_out , self .float8_amax_dL_dX ,
199
- self .float8_amax_dL_dW , self .float8_amax_dL_dY , self .fw_amax_initialized ,
198
+ x_fp8 , w_fp8 , maybe_b_fp8 , self .fp8_amax_y , self .fp8_amax_dL_dX ,
199
+ self .fp8_amax_dL_dW , self .fp8_amax_dL_dY , self .fw_amax_initialized ,
200
200
self .bw_amax_initialized )
201
201
202
202
if not is_fw_amax_initialized :
0 commit comments