@@ -41,7 +41,7 @@ def forward(
41
41
x_fp8 , w_fp8 , b_fp8 , float8_amax_dL_dX , float8_amax_dL_dW , float8_amax_dL_dY ,
42
42
bw_amax_initialized )
43
43
orig_shape = x_fp8 ._data .shape
44
- x_fp8_data_reshaped = x_fp8 . _data .reshape (- 1 , orig_shape [- 1 ])
44
+ x_fp8_reshaped = x_fp8 .reshape (- 1 , orig_shape [- 1 ])
45
45
is_fw_amax_initialized = torch .any (fw_amax_initialized )
46
46
47
47
if b_fp8 is not None :
@@ -50,29 +50,29 @@ def forward(
50
50
with torch .no_grad ():
51
51
ref_result = torch .addmm (
52
52
b_fp8 .to_original_precision (),
53
- x_fp8 .to_original_precision (). reshape ( - 1 , orig_shape [ - 1 ] ),
54
- w_fp8 .to_original_precision ().t ())
53
+ x_fp8_reshaped .to_original_precision (),
54
+ w_fp8 .t ().to_original_precision ())
55
55
float8_amax_out .fill_ (tensor_to_amax (ref_result ))
56
56
57
57
y_scale = amax_to_scale (float8_amax_out , torch .float8_e4m3fn )
58
58
res_bits = torch .ops .aten .addmm_float8 (
59
59
b_fp8 ._data , b_fp8 ._scale ,
60
- x_fp8_data_reshaped , x_fp8 ._scale ,
61
- w_fp8 ._data . t (), w_fp8 ._scale ,
60
+ x_fp8_reshaped . _data , x_fp8 ._scale ,
61
+ w_fp8 .t (). _data , w_fp8 ._scale ,
62
62
float8_amax_out , y_scale , torch .float8_e4m3fn )
63
63
else :
64
64
if not is_fw_amax_initialized :
65
65
# calculate reference amax of output
66
66
with torch .no_grad ():
67
67
ref_result = torch .mm (
68
- x_fp8 .to_original_precision (). reshape ( - 1 , orig_shape [ - 1 ] ),
69
- w_fp8 .to_original_precision ().t ())
68
+ x_fp8_reshaped .to_original_precision (),
69
+ w_fp8 .t ().to_original_precision ())
70
70
float8_amax_out .fill_ (tensor_to_amax (ref_result ))
71
71
72
72
y_scale = amax_to_scale (float8_amax_out , torch .float8_e4m3fn )
73
73
res_bits = torch .ops .aten .mm_float8 (
74
- x_fp8_data_reshaped , x_fp8 ._scale ,
75
- w_fp8 ._data . t (), w_fp8 ._scale ,
74
+ x_fp8_reshaped . _data , x_fp8 ._scale ,
75
+ w_fp8 .t (). _data , w_fp8 ._scale ,
76
76
float8_amax_out , y_scale , torch .float8_e4m3fn )
77
77
res_bits = res_bits .reshape (* orig_shape [:- 1 ], res_bits .shape [- 1 ])
78
78
@@ -101,39 +101,39 @@ def backward(ctx, go):
101
101
go_fp8 = go
102
102
103
103
go_fp8_orig_shape = go_fp8 ._data .shape
104
- go_fp8_data_reshaped = go_fp8 . _data .reshape (- 1 , go_fp8_orig_shape [- 1 ])
104
+ go_fp8_reshaped = go_fp8 .reshape (- 1 , go_fp8_orig_shape [- 1 ])
105
105
106
106
if not is_bw_amax_initialized :
107
107
# calculate reference amax of output
108
108
with torch .no_grad ():
109
109
dL_dX_ref = torch .mm (
110
- go_fp8 .to_original_precision (). reshape ( - 1 , go_fp8_orig_shape [ - 1 ] ),
110
+ go_fp8_reshaped .to_original_precision (),
111
111
w_fp8 .to_original_precision ())
112
112
float8_amax_dL_dX .fill_ (tensor_to_amax (dL_dX_ref ))
113
113
114
114
dL_dX_scale = amax_to_scale (float8_amax_dL_dX , torch .float8_e5m2 )
115
115
dL_dX_bits = torch .ops .aten .mm_float8 (
116
- go_fp8_data_reshaped , go_fp8 ._scale ,
116
+ go_fp8_reshaped . _data , go_fp8 ._scale ,
117
117
w_fp8 ._data , w_fp8 ._scale ,
118
118
float8_amax_dL_dX , dL_dX_scale , torch .float8_e5m2 )
119
119
dL_dX_bits = dL_dX_bits .reshape (* go_fp8_orig_shape [:- 1 ], dL_dX_bits .shape [- 1 ])
120
120
dL_dX_fp8 = Float8Tensor (dL_dX_bits , dL_dX_scale , go_fp8 ._orig_dtype )
121
121
122
122
x_fp8_orig_shape = x_fp8 ._data .shape
123
- x_fp8_data_reshaped = x_fp8 . _data .reshape (- 1 , x_fp8_orig_shape [- 1 ])
123
+ x_fp8_reshaped = x_fp8 .reshape (- 1 , x_fp8_orig_shape [- 1 ])
124
124
125
125
if not is_bw_amax_initialized :
126
126
# calculate reference amax of output
127
127
with torch .no_grad ():
128
128
dL_dW_ref = torch .mm (
129
- x_fp8 . to_original_precision ().reshape ( - 1 , x_fp8_orig_shape [ - 1 ]). t (),
130
- go_fp8 .to_original_precision (). reshape ( - 1 , go_fp8_orig_shape [ - 1 ] )).t ()
129
+ x_fp8_reshaped . t ().to_original_precision (),
130
+ go_fp8_reshaped .to_original_precision ()).t ()
131
131
float8_amax_dL_dW .fill_ (tensor_to_amax (dL_dW_ref ))
132
132
133
133
dL_dW_scale = amax_to_scale (float8_amax_dL_dW , torch .float8_e5m2 )
134
134
dL_dW_bits = torch .ops .aten .mm_float8 (
135
- x_fp8_data_reshaped .t (), x_fp8 ._scale ,
136
- go_fp8_data_reshaped , go_fp8 ._scale ,
135
+ x_fp8_reshaped .t (). _data , x_fp8 ._scale ,
136
+ go_fp8_reshaped . _data , go_fp8 ._scale ,
137
137
float8_amax_dL_dW , dL_dW_scale , torch .float8_e5m2 ).t ()
138
138
dL_dW_fp8 = Float8Tensor (dL_dW_bits , dL_dW_scale , go_fp8 ._orig_dtype )
139
139
0 commit comments