Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

make Float8Tensor reshape'able and t'able #19

Merged
merged 1 commit into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions float8_playground/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def forward(
x_fp8, w_fp8, b_fp8, float8_amax_dL_dX, float8_amax_dL_dW, float8_amax_dL_dY,
bw_amax_initialized)
orig_shape = x_fp8._data.shape
x_fp8_data_reshaped = x_fp8._data.reshape(-1, orig_shape[-1])
x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
is_fw_amax_initialized = torch.any(fw_amax_initialized)

if b_fp8 is not None:
Expand All @@ -50,29 +50,29 @@ def forward(
with torch.no_grad():
ref_result = torch.addmm(
b_fp8.to_original_precision(),
x_fp8.to_original_precision().reshape(-1, orig_shape[-1]),
w_fp8.to_original_precision().t())
x_fp8_reshaped.to_original_precision(),
w_fp8.t().to_original_precision())
float8_amax_out.fill_(tensor_to_amax(ref_result))

y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
res_bits = torch.ops.aten.addmm_float8(
b_fp8._data, b_fp8._scale,
x_fp8_data_reshaped, x_fp8._scale,
w_fp8._data.t(), w_fp8._scale,
x_fp8_reshaped._data, x_fp8._scale,
w_fp8.t()._data, w_fp8._scale,
float8_amax_out, y_scale, torch.float8_e4m3fn)
else:
if not is_fw_amax_initialized:
# calculate reference amax of output
with torch.no_grad():
ref_result = torch.mm(
x_fp8.to_original_precision().reshape(-1, orig_shape[-1]),
w_fp8.to_original_precision().t())
x_fp8_reshaped.to_original_precision(),
w_fp8.t().to_original_precision())
float8_amax_out.fill_(tensor_to_amax(ref_result))

y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
res_bits = torch.ops.aten.mm_float8(
x_fp8_data_reshaped, x_fp8._scale,
w_fp8._data.t(), w_fp8._scale,
x_fp8_reshaped._data, x_fp8._scale,
w_fp8.t()._data, w_fp8._scale,
float8_amax_out, y_scale, torch.float8_e4m3fn)
res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])

Expand Down Expand Up @@ -101,39 +101,39 @@ def backward(ctx, go):
go_fp8 = go

go_fp8_orig_shape = go_fp8._data.shape
go_fp8_data_reshaped = go_fp8._data.reshape(-1, go_fp8_orig_shape[-1])
go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])

if not is_bw_amax_initialized:
# calculate reference amax of output
with torch.no_grad():
dL_dX_ref = torch.mm(
go_fp8.to_original_precision().reshape(-1, go_fp8_orig_shape[-1]),
go_fp8_reshaped.to_original_precision(),
w_fp8.to_original_precision())
float8_amax_dL_dX.fill_(tensor_to_amax(dL_dX_ref))

dL_dX_scale = amax_to_scale(float8_amax_dL_dX, torch.float8_e5m2)
dL_dX_bits = torch.ops.aten.mm_float8(
go_fp8_data_reshaped, go_fp8._scale,
go_fp8_reshaped._data, go_fp8._scale,
w_fp8._data, w_fp8._scale,
float8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
dL_dX_bits = dL_dX_bits.reshape(*go_fp8_orig_shape[:-1], dL_dX_bits.shape[-1])
dL_dX_fp8 = Float8Tensor(dL_dX_bits, dL_dX_scale, go_fp8._orig_dtype)

x_fp8_orig_shape = x_fp8._data.shape
x_fp8_data_reshaped = x_fp8._data.reshape(-1, x_fp8_orig_shape[-1])
x_fp8_reshaped = x_fp8.reshape(-1, x_fp8_orig_shape[-1])

if not is_bw_amax_initialized:
# calculate reference amax of output
with torch.no_grad():
dL_dW_ref = torch.mm(
x_fp8.to_original_precision().reshape(-1, x_fp8_orig_shape[-1]).t(),
go_fp8.to_original_precision().reshape(-1, go_fp8_orig_shape[-1])).t()
x_fp8_reshaped.t().to_original_precision(),
go_fp8_reshaped.to_original_precision()).t()
float8_amax_dL_dW.fill_(tensor_to_amax(dL_dW_ref))

dL_dW_scale = amax_to_scale(float8_amax_dL_dW, torch.float8_e5m2)
dL_dW_bits = torch.ops.aten.mm_float8(
x_fp8_data_reshaped.t(), x_fp8._scale,
go_fp8_data_reshaped, go_fp8._scale,
x_fp8_reshaped.t()._data, x_fp8._scale,
go_fp8_reshaped._data, go_fp8._scale,
float8_amax_dL_dW, dL_dW_scale, torch.float8_e5m2).t()
dL_dW_fp8 = Float8Tensor(dL_dW_bits, dL_dW_scale, go_fp8._orig_dtype)

Expand Down
13 changes: 13 additions & 0 deletions float8_playground/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ def to_float8(cls, tensor, scale, dtype):

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func is aten.view.default:
orig_tensor, view_args = args
new_tensor = Float8Tensor(
orig_tensor._data.view(*view_args), orig_tensor._scale,
orig_tensor._orig_dtype)
return new_tensor
elif func is aten.t.default:
orig_tensor, = args
new_tensor = Float8Tensor(
orig_tensor._data.t(), orig_tensor._scale,
orig_tensor._orig_dtype)
return new_tensor

# for all ops that get here, fall back to original precision
def unwrap(t):
if isinstance(t, Float8Tensor):
Expand Down
19 changes: 19 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ def test_preserves_dtype(self):
x3_hp = x2_lp.to_original_precision()
self.assertTrue(x3_hp.dtype == hp_dtype)

def test_reshape(self):
x1_fp32 = torch.randn(4, 4, device='cuda')
x1_s = tensor_to_scale(x1_fp32, torch.float8_e4m3fn)
x1_fp8 = Float8Tensor.to_float8(x1_fp32, x1_s, torch.float8_e4m3fn)
new_shape = (2, -1)
x2_fp32 = x1_fp32.reshape(*new_shape)
x2_fp8 = x1_fp8.reshape(*new_shape)
self.assertTrue(x2_fp8.shape == x2_fp32.shape)
self.assertTrue(type(x2_fp8) == Float8Tensor)

def test_transpose(self):
x1_fp32 = torch.randn(4, 4, device='cuda')
x1_s = tensor_to_scale(x1_fp32, torch.float8_e4m3fn)
x1_fp8 = Float8Tensor.to_float8(x1_fp32, x1_s, torch.float8_e4m3fn)
x2_fp32 = x1_fp32.t()
x2_fp8 = x1_fp8.t()
self.assertTrue(x2_fp8.shape == x2_fp32.shape)
self.assertTrue(type(x2_fp8) == Float8Tensor)


class Float8LinearUnitTest(unittest.TestCase):
def _test_linear_impl(self, x, m_ref):
Expand Down