This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 8c7fd2f

Merge pull request #15 from pytorch-labs/test_autocast

enable autocast support for single GPU

2 parents: af60830 + 6c34ec4

File tree: 2 files changed, 27 additions and 2 deletions

float8_playground/float8_linear.py

Lines changed: 7 additions & 0 deletions
@@ -113,6 +113,13 @@ def __init__(self, *args, **kwargs):
 
     def forward(self, x):
         if not isinstance(x, Float8Tensor):
+            # Duplicate the autocast logic for F.linear, so that the output
+            # of our module has the right original precision
+            if torch.is_autocast_enabled():
+                # For now, hardcode to GPU's autocast dtype;
+                # if we need CPU support in the future, we can add it
+                x = x.to(torch.get_autocast_gpu_dtype())
+
             # TODO(future): switch to delayed scaling
             self.fp8_s_in.fill_(tensor_to_scale(x, torch.float8_e4m3fn))
             x_fp8 = Float8Tensor.to_float8(x, self.fp8_s_in, torch.float8_e4m3fn)
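For context, a minimal sketch (not part of this commit) of the behavior the added branch mirrors: under autocast on CUDA, F.linear casts float32 inputs to the GPU autocast dtype before the matmul, so Float8Linear applies the same cast before converting to float8. All calls below are standard PyTorch APIs; the example assumes a CUDA device is available.

import torch

x = torch.randn(4, 4, device='cuda')  # float32 input
with torch.autocast('cuda'):
    assert torch.is_autocast_enabled()
    # the default GPU autocast dtype is torch.float16
    x = x.to(torch.get_autocast_gpu_dtype())
assert x.dtype == torch.half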

tests/test.py

Lines changed: 20 additions & 2 deletions
@@ -60,8 +60,8 @@ def _test_linear_impl(self, x, m_ref):
         g_sqnr = compute_error(m_ref.weight.grad, m_fp8.weight.grad)
 
         # verify sqnr is reasonable
-        self.assertTrue(y_sqnr >= 22.0)
-        self.assertTrue(g_sqnr >= 22.0)
+        self.assertTrue(y_sqnr >= 18.0)
+        self.assertTrue(g_sqnr >= 18.0)
         if m_ref.bias is not None:
             torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad)
 
@@ -96,6 +96,24 @@ def test_linear_bias(self):
         m_ref = nn.Linear(3, 4, bias=True, device='cuda')
         self._test_linear_impl(x, m_ref)
 
+    def test_autocast(self):
+        # for now the support is very simple:
+        # 1. if autocast is off, the output of Float8Linear has _orig_dtype set to float
+        # 2. if autocast is on, the output of Float8Linear has _orig_dtype set to half
+
+        m = nn.Linear(4, 4, device='cuda')
+        m = Float8Linear.from_float(m)
+
+        # autocast off
+        x = torch.randn(4, 4, device='cuda')
+        y = m(x)
+        self.assertTrue(y._orig_dtype == torch.float)
+
+        # autocast on
+        with torch.autocast('cuda'):
+            y = m(x)
+        self.assertTrue(y._orig_dtype == torch.half)
+
 
 if __name__ == '__main__':
     unittest.main()
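The y_sqnr and g_sqnr values in the first hunk come from compute_error, which the test file imports. A plausible sketch of such a signal-to-quantization-noise-ratio helper, in dB (the repo's actual implementation may differ):

import torch

def compute_error(x, y):
    # signal-to-quantization-noise ratio in dB:
    # 20 * log10(||x|| / ||x - y||)
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)

Under this definition, lowering the threshold from 22.0 to 18.0 dB tolerates about 10^(4/20) ≈ 1.6x more relative error, which accommodates the half-precision inputs produced under autocast.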
