 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 import random
 import unittest

 import pytest

 import torch
 import torch.nn as nn
 from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
+from float8_experimental.float8_tensor import Float8Tensor
+
+from torch._dynamo.test_case import TestCase as DynamoTestCase
+from torch._dynamo.testing import CompileCounterWithBackend

 # Setting to unblock for calling contiguous in backwards
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -76,5 +81,87 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype
     _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)


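+# The tests below exercise torch.compile when a Float8Tensor (a tensor
+# subclass) crosses a graph boundary: across an explicit graph break, as a
+# graph input, and as a graph output.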
+class TestGraphBreaks(DynamoTestCase):
+    class MockLinear(torch.nn.Module):
+        def __init__(self, graph_break: bool):
+            super().__init__()
+            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
+            self.register_buffer("fp8_scale_x", torch.tensor(1.0))
+            self.graph_break = graph_break
+
+        def forward(self, x):
+            x_fp8 = Float8Tensor.to_float8(
+                x,
+                self.fp8_scale_x,
+                torch.float8_e4m3fn,
+                self.fp8_amax_x,
+                emulate=True,  # TODO: I set this to True so that people on A100 can test, but once fix is in, set to False
+            )
+            if self.graph_break:
+                torch._dynamo.graph_break()
+                x_hp = x_fp8.to_original_precision()
+                return x_hp
+            return x_fp8
+
+    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
+    def test_float8_with_graph_break_in_the_middle(self):
+        """Test that having a Float8Tensor object at the boundary of a subgraph works"""
+        cnts = CompileCounterWithBackend("inductor")
+        mod = self.MockLinear(graph_break=True).cuda()
+        compiled_mod = copy.deepcopy(mod)
+        compiled_mod = torch.compile(compiled_mod, backend=cnts)
+        x = torch.randn(16, 16, device="cuda")
+        y_eager = mod(x)
+        y_compiled = compiled_mod(x)
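+        # The explicit graph break splits compilation into two frames; the
+        # compiled result must still match eager numerics.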
+        self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
+        torch.testing.assert_close(y_eager, y_compiled)
+
+    def test_float8_graph_input(self):
+        """Test that having a Float8Tensor object as a graph input works"""
+
+        def to_float(x):
+            return x.to_original_precision()
+
+        cnts = CompileCounterWithBackend("inductor")
+        mod = self.MockLinear(graph_break=False).cuda()
+        x = torch.randn(2, 2, device="cuda")
+        compiled_to_float = torch.compile(to_float, backend=cnts)
+        y = mod(x)
+        y2_eager = to_float(y)
+        y2_compiled = compiled_to_float(y)
+        self.assertEqual(
+            cnts.frame_count,
+            1,
+            "to_float was not compiled into 1 frame and likely encountered a skip!",
+        )
+        torch.testing.assert_close(y2_eager, y2_compiled)
+
+    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
+    def test_float8_graph_output(self):
+        """Test that having a Float8Tensor object as a graph output works"""
+        cnts = CompileCounterWithBackend("inductor")
+        mod = self.MockLinear(graph_break=False).cuda()
+        compiled_mod = torch.compile(mod, backend=cnts)
+        x = torch.randn(16, 16, device="cuda")
+        y_compiled = compiled_mod(x)
+
+        self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
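+        # __tensor_flatten__ returns the names of the inner tensors plus a
+        # context object; verify that compile did not leak FakeTensors or fake
+        # metadata into the Float8Tensor returned from the compiled graph.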
+        tensors, ctx = y_compiled.__tensor_flatten__()
+        for tensor in tensors:
+            assert not isinstance(
+                getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor
+            ), "Float8Tensor should not contain any FakeTensors!"
+        assert isinstance(
+            y_compiled._orig_dtype, torch.dtype
+        ), "Float8Tensor._orig_dtype should be a dtype but got {}".format(
+            type(y_compiled._orig_dtype)
+        )
+        assert isinstance(
+            y_compiled._emulate, bool
+        ), "Float8Tensor._emulate should be a bool but got {}".format(
+            type(y_compiled._emulate)
+        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])