This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 9d8d2bc

Merge pull request #19 from pytorch-labs/float8_reshape
make Float8Tensor reshape'able and t'able
2 parents: 145f31a + f5609fb

File tree: 3 files changed (+49, -17 lines)


float8_playground/float8_linear.py

Lines changed: 17 additions & 17 deletions
@@ -41,7 +41,7 @@ def forward(
             x_fp8, w_fp8, b_fp8, float8_amax_dL_dX, float8_amax_dL_dW, float8_amax_dL_dY,
             bw_amax_initialized)
         orig_shape = x_fp8._data.shape
-        x_fp8_data_reshaped = x_fp8._data.reshape(-1, orig_shape[-1])
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
         is_fw_amax_initialized = torch.any(fw_amax_initialized)
 
         if b_fp8 is not None:
@@ -50,29 +50,29 @@ def forward(
                 with torch.no_grad():
                     ref_result = torch.addmm(
                         b_fp8.to_original_precision(),
-                        x_fp8.to_original_precision().reshape(-1, orig_shape[-1]),
-                        w_fp8.to_original_precision().t())
+                        x_fp8_reshaped.to_original_precision(),
+                        w_fp8.t().to_original_precision())
                     float8_amax_out.fill_(tensor_to_amax(ref_result))
 
             y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
             res_bits = torch.ops.aten.addmm_float8(
                 b_fp8._data, b_fp8._scale,
-                x_fp8_data_reshaped, x_fp8._scale,
-                w_fp8._data.t(), w_fp8._scale,
+                x_fp8_reshaped._data, x_fp8._scale,
+                w_fp8.t()._data, w_fp8._scale,
                 float8_amax_out, y_scale, torch.float8_e4m3fn)
         else:
             if not is_fw_amax_initialized:
                 # calculate reference amax of output
                 with torch.no_grad():
                     ref_result = torch.mm(
-                        x_fp8.to_original_precision().reshape(-1, orig_shape[-1]),
-                        w_fp8.to_original_precision().t())
+                        x_fp8_reshaped.to_original_precision(),
+                        w_fp8.t().to_original_precision())
                     float8_amax_out.fill_(tensor_to_amax(ref_result))
 
             y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
             res_bits = torch.ops.aten.mm_float8(
-                x_fp8_data_reshaped, x_fp8._scale,
-                w_fp8._data.t(), w_fp8._scale,
+                x_fp8_reshaped._data, x_fp8._scale,
+                w_fp8.t()._data, w_fp8._scale,
                 float8_amax_out, y_scale, torch.float8_e4m3fn)
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
 
@@ -101,39 +101,39 @@ def backward(ctx, go):
             go_fp8 = go
 
         go_fp8_orig_shape = go_fp8._data.shape
-        go_fp8_data_reshaped = go_fp8._data.reshape(-1, go_fp8_orig_shape[-1])
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
 
         if not is_bw_amax_initialized:
             # calculate reference amax of output
             with torch.no_grad():
                 dL_dX_ref = torch.mm(
-                    go_fp8.to_original_precision().reshape(-1, go_fp8_orig_shape[-1]),
+                    go_fp8_reshaped.to_original_precision(),
                     w_fp8.to_original_precision())
                 float8_amax_dL_dX.fill_(tensor_to_amax(dL_dX_ref))
 
         dL_dX_scale = amax_to_scale(float8_amax_dL_dX, torch.float8_e5m2)
         dL_dX_bits = torch.ops.aten.mm_float8(
-            go_fp8_data_reshaped, go_fp8._scale,
+            go_fp8_reshaped._data, go_fp8._scale,
             w_fp8._data, w_fp8._scale,
             float8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
         dL_dX_bits = dL_dX_bits.reshape(*go_fp8_orig_shape[:-1], dL_dX_bits.shape[-1])
         dL_dX_fp8 = Float8Tensor(dL_dX_bits, dL_dX_scale, go_fp8._orig_dtype)
 
         x_fp8_orig_shape = x_fp8._data.shape
-        x_fp8_data_reshaped = x_fp8._data.reshape(-1, x_fp8_orig_shape[-1])
+        x_fp8_reshaped = x_fp8.reshape(-1, x_fp8_orig_shape[-1])
 
         if not is_bw_amax_initialized:
            # calculate reference amax of output
            with torch.no_grad():
                dL_dW_ref = torch.mm(
-                    x_fp8.to_original_precision().reshape(-1, x_fp8_orig_shape[-1]).t(),
-                    go_fp8.to_original_precision().reshape(-1, go_fp8_orig_shape[-1])).t()
+                    x_fp8_reshaped.t().to_original_precision(),
+                    go_fp8_reshaped.to_original_precision()).t()
                float8_amax_dL_dW.fill_(tensor_to_amax(dL_dW_ref))
 
         dL_dW_scale = amax_to_scale(float8_amax_dL_dW, torch.float8_e5m2)
         dL_dW_bits = torch.ops.aten.mm_float8(
-            x_fp8_data_reshaped.t(), x_fp8._scale,
-            go_fp8_data_reshaped, go_fp8._scale,
+            x_fp8_reshaped.t()._data, x_fp8._scale,
+            go_fp8_reshaped._data, go_fp8._scale,
            float8_amax_dL_dW, dL_dW_scale, torch.float8_e5m2).t()
         dL_dW_fp8 = Float8Tensor(dL_dW_bits, dL_dW_scale, go_fp8._orig_dtype)
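The forward and backward paths above both collapse all leading dimensions of a tensor into one batch dimension before the matmul, then restore them on the output. A minimal sketch of that reshape-to-2D pattern with plain tensors (ordinary torch.mm stands in here for this repo's mm_float8 op):

```python
import torch

# Collapse leading dims, matmul against w.t(), restore leading dims.
x = torch.randn(2, 3, 8)   # input with extra leading dims; features last
w = torch.randn(4, 8)      # linear weight: (out_features, in_features)

orig_shape = x.shape
x_reshaped = x.reshape(-1, orig_shape[-1])           # (6, 8)
res = torch.mm(x_reshaped, w.t())                    # (6, 4)
res = res.reshape(*orig_shape[:-1], res.shape[-1])   # (2, 3, 4)
assert res.shape == (2, 3, 4)
```

The change in this PR is that the reshape now happens on the Float8Tensor wrapper itself rather than on its raw `_data`, so the `_scale` and `_orig_dtype` metadata travel with the reshaped tensor.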

float8_playground/float8_tensor.py

Lines changed: 13 additions & 0 deletions
@@ -88,6 +88,19 @@ def to_float8(cls, tensor, scale, dtype):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func is aten.view.default:
+            orig_tensor, view_args = args
+            new_tensor = Float8Tensor(
+                orig_tensor._data.view(*view_args), orig_tensor._scale,
+                orig_tensor._orig_dtype)
+            return new_tensor
+        elif func is aten.t.default:
+            orig_tensor, = args
+            new_tensor = Float8Tensor(
+                orig_tensor._data.t(), orig_tensor._scale,
+                orig_tensor._orig_dtype)
+            return new_tensor
         # for all ops that get here, fall back to original precision
         def unwrap(t):
             if isinstance(t, Float8Tensor):
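These handlers follow the standard `__torch_dispatch__` interposition pattern: catch the aten op, apply it to the raw `_data` payload, and rewrap the result with the unchanged scale metadata. A self-contained sketch of the same idea, with a hypothetical `ScaledTensor` standing in for this repo's Float8Tensor:

```python
import torch

aten = torch.ops.aten

class ScaledTensor(torch.Tensor):
    """Toy wrapper subclass: a payload tensor plus a scale, mirroring the
    _data/_scale layout above (hypothetical, for illustration only)."""

    @staticmethod
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if func is aten.view.default:
            # reshape on a contiguous tensor decomposes to aten.view,
            # which is why handling view makes reshape work
            orig_tensor, view_args = args
            return ScaledTensor(orig_tensor._data.view(view_args), orig_tensor._scale)
        elif func is aten.t.default:
            orig_tensor, = args
            return ScaledTensor(orig_tensor._data.t(), orig_tensor._scale)
        raise NotImplementedError(f"{func} is not handled in this sketch")

x = ScaledTensor(torch.randn(4, 4), scale=2.0)
print(type(x.reshape(2, -1)), x.reshape(2, -1).shape)  # ScaledTensor, (2, 8)
print(type(x.t()), x.t().shape)                        # ScaledTensor, (4, 4)
```

Rewrapping matters because any op not handled here falls into the original-precision fallback below, which returns a plain tensor and silently drops the fp8 representation.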

tests/test.py

Lines changed: 19 additions & 0 deletions
@@ -45,6 +45,25 @@ def test_preserves_dtype(self):
         x3_hp = x2_lp.to_original_precision()
         self.assertTrue(x3_hp.dtype == hp_dtype)
 
+    def test_reshape(self):
+        x1_fp32 = torch.randn(4, 4, device='cuda')
+        x1_s = tensor_to_scale(x1_fp32, torch.float8_e4m3fn)
+        x1_fp8 = Float8Tensor.to_float8(x1_fp32, x1_s, torch.float8_e4m3fn)
+        new_shape = (2, -1)
+        x2_fp32 = x1_fp32.reshape(*new_shape)
+        x2_fp8 = x1_fp8.reshape(*new_shape)
+        self.assertTrue(x2_fp8.shape == x2_fp32.shape)
+        self.assertTrue(type(x2_fp8) == Float8Tensor)
+
+    def test_transpose(self):
+        x1_fp32 = torch.randn(4, 4, device='cuda')
+        x1_s = tensor_to_scale(x1_fp32, torch.float8_e4m3fn)
+        x1_fp8 = Float8Tensor.to_float8(x1_fp32, x1_s, torch.float8_e4m3fn)
+        x2_fp32 = x1_fp32.t()
+        x2_fp8 = x1_fp8.t()
+        self.assertTrue(x2_fp8.shape == x2_fp32.shape)
+        self.assertTrue(type(x2_fp8) == Float8Tensor)
+
 
 class Float8LinearUnitTest(unittest.TestCase):
     def _test_linear_impl(self, x, m_ref):
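Both tests hinge on the final assertion: before this change, reshape and t on a Float8Tensor fell through to the "fall back to original precision" path in `__torch_dispatch__` and returned plain high-precision tensors, so `type(x2_fp8) == Float8Tensor` would have failed. The tests allocate on `device='cuda'`, so they need a GPU; presumably they run with the repo's usual test entry point (e.g. `python tests/test.py`).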
