 import torch

 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, re_construct_float8_weight
 from float8_experimental.float8_utils import is_row_major
 from torch.utils._pytree import tree_map

@@ -191,9 +191,7 @@ def forward(
         if recompute_float8_weight:
             # This should be set to True when using traditional fsdp to avoid
             # saving the unsharded weight for backwards
-            ctx.save_for_backward(
-                x_fp8, original_weight, weight_scale, weight_amax_buffer
-            )
+            ctx.save_for_backward(x_fp8, original_weight, weight_scale)
         else:
             # Does this interact properly with activation checkpointing?
             ctx.save_for_backward(x_fp8, w_fp8)
@@ -211,19 +209,15 @@ def forward(
     @staticmethod
     def backward(ctx, go_fp8: torch.Tensor):
         if ctx.recompute_float8_weight:
-            x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors
-            w_fp8 = Float8Tensor.to_float8(
-                original_weight,
-                weight_scale,
-                torch.float8_e4m3fn,
-                weight_amax_buffer,
-                emulate=ctx.emulate,
+            x_fp8, original_weight, weight_scale = ctx.saved_tensors
+            w_fp8 = re_construct_float8_weight(
+                original_weight, weight_scale, torch.float8_e4m3fn, emulate=ctx.emulate
             )
         else:
             x_fp8, w_fp8 = ctx.saved_tensors

         # calculate dL/dX
-        go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1))
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
         dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1))
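The diff saves only the high-precision weight and its scale, then re-casts the weight to fp8 inside backward via the new re_construct_float8_weight import, so the unsharded fp8 weight (and its amax buffer) need not be kept alive under FSDP. As a rough illustration only, here is a minimal standalone sketch of that kind of re-cast, assuming the scale was already computed during forward; the function name recast_weight_to_float8 is hypothetical, and the real helper in float8_tensor.py additionally wraps the result in a Float8Tensor and handles the emulate path.

import torch

def recast_weight_to_float8(
    original_weight: torch.Tensor,
    weight_scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    # Hypothetical stand-in for re_construct_float8_weight: forward already
    # computed weight_scale (and updated the amax buffer), so backward only
    # needs to re-quantize the saved high-precision weight with that scale.
    scaled = original_weight.to(torch.float32) * weight_scale
    # Clamp to the representable range of the target float8 format, then cast.
    finfo = torch.finfo(float8_dtype)
    return scaled.clamp(min=finfo.min, max=finfo.max).to(float8_dtype)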