
Commit 3d2a32e

[doesn't work yet] enable Float8Tensor as subgraph boundary
Summary:

In https://github.com/pytorch/pytorch/pull/114311/files, the signature expected by traceable subclasses changed. This PR updates `Float8Tensor` to the new spec. Note: doesn't work yet, need to debug. This may be blocking composability of Float8Linear + FSDP + torch.compile.

Test Plan:

```
pytest test/test_compile.py -s --sw -k graph_break
// currently broken
// logs: https://gist.github.com/vkuzo/ba98a01a459fb9c966f167d8ecca1780
```

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 8ed0eb7 commit 3d2a32e

File tree

1 file changed (+27, -0 lines)


test/test_compile.py

Lines changed: 27 additions & 0 deletions
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
+from float8_experimental.float8_tensor import Float8Tensor

 # Setting to unblock for calling contiguous in backwards
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -75,6 +76,32 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype
     torch._dynamo.reset()
     _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)

+def test_float8_with_graph_break():
+    # test that having Float8Tensor object at the boundary of a subgraph
+    # works
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
+            self.register_buffer("fp8_scale_x", torch.tensor(1.0))
+
+        def forward(self, x):
+            x_fp8 = Float8Tensor.to_float8(
+                x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x,
+                emulate=False,
+            )
+
+            # graph break
+            print('foo')
+            x_hp = x_fp8.to_original_precision()
+            return x_hp
+
+    m = M()
+    m = torch.compile(m)
+    x = torch.randn(16, 16)
+    y = m(x)
+

 if __name__ == "__main__":
     pytest.main([__file__])
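
For context on what "the new spec" in the commit summary refers to: the sketch below illustrates, in rough outline, the flatten/unflatten hooks that PyTorch's traceable tensor subclass protocol expects a wrapper subclass such as `Float8Tensor` to implement so it can cross a subgraph boundary. This is not the code from this commit; the class name, attribute names (`_data`, `_scale`, `_orig_dtype`, `_emulate`) and the exact hook signatures here are assumptions for illustration (the protocol has changed across PyTorch versions; see the PyTorch PR linked in the summary for the authoritative spec).

```python
# Illustrative sketch only -- not the code added in this commit.
# Names and signatures are assumptions; __torch_dispatch__ and the actual
# fp8 math of a real Float8Tensor are omitted.
import torch


class MyFloat8Tensor(torch.Tensor):
    """Hypothetical wrapper subclass holding fp8 bits plus a scale."""

    def __new__(cls, data, scale, orig_dtype, emulate=False):
        # Wrapper-subclass creation: the outer tensor advertises the original
        # (high precision) dtype, the inner tensors carry the fp8 payload.
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.size(), dtype=orig_dtype, device=data.device
        )
        self._data = data
        self._scale = scale
        self._orig_dtype = orig_dtype
        self._emulate = emulate
        return self

    def __tensor_flatten__(self):
        # Names of the inner tensor attributes, plus plain (non-tensor)
        # context needed to rebuild the subclass at a subgraph boundary.
        ctx = {"orig_dtype": self._orig_dtype, "emulate": self._emulate}
        return ["_data", "_scale"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx):
        # Reconstruct the subclass from its inner tensors and saved context.
        return MyFloat8Tensor(
            inner_tensors["_data"],
            inner_tensors["_scale"],
            ctx["orig_dtype"],
            ctx["emulate"],
        )
```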
