
add numerical test on SAM encoder #5

Merged: 1 commit, Jul 23, 2023
32 changes: 27 additions & 5 deletions float8_playground/float8_linear.py
@@ -32,17 +32,20 @@ def forward(
     ):
         ctx.save_for_backward(
             x_fp8, w_fp8, b_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY)
+        orig_shape = x_fp8._data.shape
+        x_fp8_data_reshaped = x_fp8._data.reshape(-1, orig_shape[-1])
         if b_fp8 is not None:
             res_bits = torch.ops.aten.addmm_float8(
                 b_fp8._data, b_fp8._scale,
-                x_fp8._data, x_fp8._scale,
+                x_fp8_data_reshaped, x_fp8._scale,
                 w_fp8._data.t(), w_fp8._scale,
                 fp8_s_out, torch.float8_e4m3fn)
         else:
             res_bits = torch.ops.aten.mm_float8(
-                x_fp8._data, x_fp8._scale,
+                x_fp8_data_reshaped, x_fp8._scale,
                 w_fp8._data.t(), w_fp8._scale,
                 fp8_s_out, torch.float8_e4m3fn)
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])

         res = Float8Tensor(res_bits, fp8_s_out)
         # scale update would also happen here, for now no-op
@@ -62,16 +65,25 @@ def backward(ctx, go):
         else:
             go_fp8 = go

+        go_fp8_orig_shape = go_fp8._data.shape
+        go_fp8_data_reshaped = go_fp8._data.reshape(-1, go_fp8_orig_shape[-1])
+
         dL_dX_bits = torch.ops.aten.mm_float8(
-            go_fp8._data, go_fp8._scale,
+            go_fp8_data_reshaped, go_fp8._scale,
             w_fp8._data, w_fp8._scale,
             fp8_s_dL_dX, torch.float8_e5m2)
+        dL_dX_bits = dL_dX_bits.reshape(*go_fp8_orig_shape[:-1], dL_dX_bits.shape[-1])
         dL_dX_fp8 = Float8Tensor(dL_dX_bits, fp8_s_dL_dX)

+        x_fp8_orig_shape = x_fp8._data.shape
+        x_fp8_data_reshaped = x_fp8._data.reshape(-1, x_fp8_orig_shape[-1])
+
         dL_dW_bits = torch.ops.aten.mm_float8(
-            x_fp8._data.t(), x_fp8._scale,
-            go_fp8._data, go_fp8._scale,
+            x_fp8_data_reshaped.t(), x_fp8._scale,
+            go_fp8_data_reshaped, go_fp8._scale,
             fp8_s_dL_dW, torch.float8_e5m2).t()
+        # import pdb; pdb.set_trace()
+        # dL_dW_bits = dL_dW_bits.reshape(*x_fp8_orig_shape[:-1], dL_dW_bits.shape[-1])
         dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW)

         # scale update would also happen here, for now no-op
@@ -134,3 +146,13 @@ def from_float(cls, mod):
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         return new_mod
+
+
+def swap_linear_with_float8_linear(model):
+    name_to_child = dict(model.named_children())
+    for name, child in name_to_child.items():
+        if isinstance(child, torch.nn.Linear):
+            new_child = Float8Linear.from_float(child)
+            setattr(model, name, new_child)
+        else:
+            swap_linear_with_float8_linear(child)
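
For reference, a minimal usage sketch of the new swap helper (not part of the diff). It assumes Float8Linear and swap_linear_with_float8_linear are importable from float8_linear, as they are in test_sam.py below; the toy model is hypothetical and only illustrates that nested nn.Linear children are replaced in place:

# Illustrative sketch only, not part of this PR.
# Assumes float8_linear is on the import path, as in the tests in this PR.
import torch.nn as nn

from float8_linear import Float8Linear, swap_linear_with_float8_linear

toy = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Sequential(nn.Linear(32, 8)),  # nested child, reached via the recursion
)
swap_linear_with_float8_linear(toy)  # modifies the model in place, returns None

assert isinstance(toy[0], Float8Linear)
assert isinstance(toy[2][0], Float8Linear)
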
18 changes: 12 additions & 6 deletions float8_playground/test.py
@@ -40,6 +40,8 @@ def _test_linear_impl(self, x, m_ref):
         y_ref = m_ref(x)
         y_ref.sum().backward()

+        self.assertTrue(y_ref.shape == y_fp8.shape)
+
         y_sqnr = compute_error(y_ref, y_fp8)
         g_sqnr = compute_error(m_ref.weight.grad, m_fp8.weight.grad)

@@ -67,14 +69,18 @@ def _test_linear_impl(self, x, m_ref):
                 f"{buffer_name} not filled")

     def test_linear_nobias(self):
-        x = torch.randn(2, 3)
-        m_ref = nn.Linear(3, 4, bias=False)
-        self._test_linear_impl(x, m_ref)
+        x_shapes = ((2, 3), (4, 2, 3), (5, 4, 2, 3))
+        for x_shape in x_shapes:
+            x = torch.randn(*x_shape)
+            m_ref = nn.Linear(3, 4, bias=False)
+            self._test_linear_impl(x, m_ref)

     def test_linear_bias(self):
-        x = torch.randn(2, 3)
-        m_ref = nn.Linear(3, 4, bias=True)
-        self._test_linear_impl(x, m_ref)
+        x_shapes = ((2, 3), (4, 2, 3), (5, 4, 2, 3))
+        for x_shape in x_shapes:
+            x = torch.randn(*x_shape)
+            m_ref = nn.Linear(3, 4, bias=True)
+            self._test_linear_impl(x, m_ref)


 if __name__ == '__main__':
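
The new (4, 2, 3) and (5, 4, 2, 3) input shapes above exercise the reshape path added to float8_linear.py. As a sanity check on that pattern (collapse the leading dims to 2-D for the matmul, then restore them), the equivalent plain-PyTorch computation, with no float8 involved, is sketched below; the tensors are illustrative only:

# Plain-PyTorch sketch of the reshape pattern used in the fp8 forward above;
# illustrative only, no float8 ops involved.
import torch

x = torch.randn(5, 4, 2, 3)   # same trailing feature dim as the tests
w = torch.randn(4, 3)         # out_features=4, in_features=3

orig_shape = x.shape
x_2d = x.reshape(-1, orig_shape[-1])                # (5*4*2, 3)
y_2d = torch.mm(x_2d, w.t())                        # (5*4*2, 4)
y = y_2d.reshape(*orig_shape[:-1], y_2d.shape[-1])  # (5, 4, 2, 4)

torch.testing.assert_close(y, torch.nn.functional.linear(x, w))
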
54 changes: 54 additions & 0 deletions float8_playground/test_sam.py
@@ -0,0 +1,54 @@
+# Tests SAM with real weights with float8
+# if we want finetuning later, we can use
+# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb
+
+import copy
+import unittest
+
+import torch
+
+from transformers import SamModel
+
+from float8_linear import swap_linear_with_float8_linear
+from float8_utils import compute_error
+
+torch.manual_seed(0)
+
+class Float8SAMIntegrationTest(unittest.TestCase):
+
+    def test_encoder_fw_bw(self):
+        model = SamModel.from_pretrained("facebook/sam-vit-base")
+        # print(model)
+
+        # for now just test the encoder to simplify things
+        encoder_ref = model.vision_encoder
+        encoder_fp8 = copy.deepcopy(encoder_ref)
+        swap_linear_with_float8_linear(encoder_fp8)
+
+        # an image
+        data = torch.randn(1, 3, 1024, 1024)
+
+        encoder_ref_out = encoder_ref(data)
+        last_hidden_ref = encoder_ref_out.last_hidden_state
+        last_hidden_ref.sum().backward()
+
+        encoder_fp8_out = encoder_fp8(data)
+        last_hidden_fp8 = encoder_fp8_out.last_hidden_state
+        last_hidden_fp8.sum().backward()
+
+        hidden_sqnr = compute_error(last_hidden_ref, last_hidden_fp8)
+        self.assertTrue(hidden_sqnr > 20.0)
+
+        ref_name_to_grad = \
+            {name: param.grad for name, param in encoder_ref.named_parameters()}
+        for name, param in encoder_fp8.named_parameters():
+            ref_grad = ref_name_to_grad[name]
+            cur_grad = param.grad
+            # For now below is for debugging only, numerical values of
+            # fp32 baseline vs fp8 for grads are not that close for a lot
+            # of the layers in this network
+            sqnr = compute_error(ref_grad, cur_grad)
+
+
+if __name__ == '__main__':
+    unittest.main()
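
compute_error comes from float8_utils and is not shown in this diff. For readers of the 20 dB threshold above, a rough SQNR-style sketch of what such a metric typically computes is given below; this is an assumption about its behavior, not the code from this repo:

# Hypothetical SQNR-style error metric, for illustration only; the real
# compute_error lives in float8_utils and may differ in detail.
import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # 10 * log10(signal power / noise power), in decibels
    signal_power = ref.pow(2).mean()
    noise_power = (ref - approx).pow(2).mean()
    return 10 * torch.log10(signal_power / noise_power)

# Under this definition, 20 dB means the error power is about 1% of the
# signal power (roughly 10% RMS error relative to the reference).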