Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

enable autocast support for single GPU #15

Merged
merged 1 commit into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions float8_playground/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ def __init__(self, *args, **kwargs):

def forward(self, x):
if not isinstance(x, Float8Tensor):
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
x = x.to(torch.get_autocast_gpu_dtype())

# TODO(future): switch to delayed scaling
self.fp8_s_in.fill_(tensor_to_scale(x, torch.float8_e4m3fn))
x_fp8 = Float8Tensor.to_float8(x, self.fp8_s_in, torch.float8_e4m3fn)
Expand Down
22 changes: 20 additions & 2 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def _test_linear_impl(self, x, m_ref):
g_sqnr = compute_error(m_ref.weight.grad, m_fp8.weight.grad)

# verify sqnr is reasonable
self.assertTrue(y_sqnr >= 22.0)
self.assertTrue(g_sqnr >= 22.0)
self.assertTrue(y_sqnr >= 18.0)
self.assertTrue(g_sqnr >= 18.0)
if m_ref.bias is not None:
torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad)

Expand Down Expand Up @@ -96,6 +96,24 @@ def test_linear_bias(self):
m_ref = nn.Linear(3, 4, bias=True, device='cuda')
self._test_linear_impl(x, m_ref)

def test_autocast(self):
# for now the support is very simple:
# 1. if autocast is off, output of Float8Linear has _orig_precision set to float
# 2. if autocast is on, output of Float8Linear has _orig_precision set to half

m = nn.Linear(4, 4, device='cuda')
m = Float8Linear.from_float(m)

# autocast off
x = torch.randn(4, 4, device='cuda')
y = m(x)
self.assertTrue(y._orig_dtype == torch.float)

# autocast on
with torch.autocast('cuda'):
y = m(x)
self.assertTrue(y._orig_dtype == torch.half)


if __name__ == '__main__':
unittest.main()