This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 289c122

drisspg authored and facebook-github-bot committed
Add tests for Float8Tensor at graph boundaries (#196)
Summary: This is a copy, with some tweaks, of #166

Pull Request resolved: #196
Reviewed By: bdhirsh
Differential Revision: D53056829
Pulled By: drisspg
fbshipit-source-id: 1856d2e4bb1410161e326795014a56d0f91249e0
1 parent: 713d2db

File tree: 1 file changed (+87, -0)


test/test_compile.py

Lines changed: 87 additions & 0 deletions
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 import random
 import unittest

@@ -11,6 +12,10 @@
 import torch
 import torch.nn as nn
 from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
+from float8_experimental.float8_tensor import Float8Tensor
+
+from torch._dynamo.test_case import TestCase as DynamoTestCase
+from torch._dynamo.testing import CompileCounterWithBackend

 # Setting to unblock for calling contiguous in backwards
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -76,5 +81,87 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype):
     _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)


+class TestGraphBreaks(DynamoTestCase):
+    class MockLinear(torch.nn.Module):
+        def __init__(self, graph_break: bool):
+            super().__init__()
+            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
+            self.register_buffer("fp8_scale_x", torch.tensor(1.0))
+            self.graph_break = graph_break
+
+        def forward(self, x):
+            x_fp8 = Float8Tensor.to_float8(
+                x,
+                self.fp8_scale_x,
+                torch.float8_e4m3fn,
+                self.fp8_amax_x,
+                emulate=True,  # TODO: set to True so that people on A100 can test; once the fix is in, set to False
+            )
+            if self.graph_break:
+                torch._dynamo.graph_break()
+                x_hp = x_fp8.to_original_precision()
+                return x_hp
+            return x_fp8
+
+    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
+    def test_float8_with_graph_break_in_the_middle(self):
+        """Test that having a Float8Tensor object at the boundary of a subgraph works."""
+        cnts = CompileCounterWithBackend("inductor")
+        mod = self.MockLinear(graph_break=True).cuda()
+        compiled_mod = copy.deepcopy(mod)
+        compiled_mod = torch.compile(compiled_mod, backend=cnts)
+        x = torch.randn(16, 16, device="cuda")
+        y_eager = mod(x)
+        y_compiled = compiled_mod(x)
+        self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
+        torch.testing.assert_close(y_eager, y_compiled)
+
+    def test_float8_graph_input(self):
+        """Test that having a Float8Tensor object as a graph input works."""
+
+        def to_float(x):
+            return x.to_original_precision()
+
+        cnts = CompileCounterWithBackend("inductor")
+        mod = self.MockLinear(graph_break=False).cuda()
+        x = torch.randn(2, 2, device="cuda")
+        compiled_to_float = torch.compile(to_float, backend=cnts)
+        y = mod(x)
+        y2_eager = to_float(y)
+        y2_compiled = compiled_to_float(y)
+        self.assertEqual(
+            cnts.frame_count,
+            1,
+            "to_float was not compiled into 1 frame and likely encountered a skip!",
+        )
+        torch.testing.assert_close(y2_eager, y2_compiled)
+
+    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
+    def test_float8_graph_output(self):
+        """Test that having a Float8Tensor object as a graph output works."""
+        cnts = CompileCounterWithBackend("inductor")
+        mod = self.MockLinear(graph_break=False).cuda()
+        compiled_mod = torch.compile(mod, backend=cnts)
+        x = torch.randn(16, 16, device="cuda")
+        y_compiled = compiled_mod(x)
+
+        self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
+        tensors, ctx = y_compiled.__tensor_flatten__()
+        for tensor in tensors:
+            assert not isinstance(
+                getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor
+            ), "Float8Tensor should not contain any FakeTensors!"
+        assert isinstance(
+            y_compiled._orig_dtype, torch.dtype
+        ), "Float8Tensor._orig_dtype should be a dtype but got {}".format(
+            type(y_compiled._orig_dtype)
+        )
+        assert isinstance(
+            y_compiled._emulate, bool
+        ), "Float8Tensor._emulate should be a bool but got {}".format(
+            type(y_compiled._emulate)
+        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
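A note on the frame-count assertions above: CompileCounterWithBackend wraps a real backend and counts how many graphs Dynamo compiles, and torch._dynamo.graph_break() ends the current graph so everything after it is traced as a new frame. Below is a minimal sketch of that mechanic, independent of Float8Tensor; the function f is illustrative and not part of this commit, and the "eager" backend keeps it runnable without a GPU.

import torch
from torch._dynamo.testing import CompileCounterWithBackend

# Counts how many graphs Dynamo compiles while delegating to a real backend.
cnts = CompileCounterWithBackend("eager")

def f(x):
    y = x + 1
    # Ends the current graph; y crosses the boundary as a graph output/input.
    torch._dynamo.graph_break()
    return y * 2

compiled_f = torch.compile(f, backend=cnts)
compiled_f(torch.randn(4))
assert cnts.frame_count == 2  # one frame before the break, one after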
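test_float8_graph_output leans on the tensor-subclass flattening protocol: __tensor_flatten__ names the subclass's inner tensor attributes (plus opaque context), which is how the compile stack takes the subclass apart at graph boundaries, and why a leaked FakeTensor would show up in one of those attributes. The sketch below shows the protocol on a toy wrapper subclass; it is illustrative only, not Float8Tensor's actual implementation, and the four-argument __tensor_unflatten__ signature assumes a recent PyTorch (the signature has shifted across releases).

import torch

class WrapperTensor(torch.Tensor):
    """Toy wrapper subclass demonstrating the flatten/unflatten contract."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, note: str):
        # The outer tensor mirrors the inner tensor's metadata.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.size(), dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor, note: str):
        self._data = data  # inner tensor attribute
        self._note = note  # non-tensor context

    def __tensor_flatten__(self):
        # Names of inner-tensor attributes, plus context needed to rebuild.
        return ["_data"], self._note

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        return WrapperTensor(inner_tensors["_data"], ctx)

t = WrapperTensor(torch.randn(2, 2), "metadata")
tensors, ctx = t.__tensor_flatten__()
assert tensors == ["_data"] and ctx == "metadata"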
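To run just the new tests locally, one plausible invocation is sketched below, assuming a float8_experimental checkout and a CUDA device for the .cuda() calls; the -k filter matches the class added by this commit, and two of the three tests are currently expected to xfail.

import pytest

# Run from the repository root.
pytest.main(["test/test_compile.py", "-k", "TestGraphBreaks", "-v"])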
