This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[wip] enable Float8Tensor as subgraph boundary #166

Closed
wants to merge 1 commit
82 changes: 82 additions & 0 deletions test/test_compile.py
@@ -11,6 +11,7 @@
import torch
import torch.nn as nn
from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
from float8_experimental.float8_tensor import Float8Tensor

# Setting to unblock for calling contiguous in backwards
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -75,6 +76,87 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype):
    torch._dynamo.reset()
    _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)

def test_float8_with_graph_break_in_the_middle():
    # test that having a Float8Tensor object at the boundary of a subgraph
    # works

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
            self.register_buffer("fp8_scale_x", torch.tensor(1.0))

        def forward(self, x):
            x_fp8 = Float8Tensor.to_float8(
                x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x,
                emulate=False,
            )

            # graph break
            print('foo')
            x_hp = x_fp8.to_original_precision()
            return x_hp

    m = M().cuda()
    m = torch.compile(m)
    x = torch.randn(16, 16, device='cuda')
    y = m(x)
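
    # Sketch of an extra sanity check (not in the original test): with the
    # print() graph break above, dynamo is expected to compile two separate
    # frames. CompileCounter is a debug backend from torch._dynamo.testing;
    # exact frame counts can vary across PyTorch versions, so treat this as
    # illustrative rather than authoritative.
    from torch._dynamo.testing import CompileCounter

    torch._dynamo.reset()
    cnt = CompileCounter()
    m_counted = torch.compile(M().cuda(), backend=cnt)
    m_counted(x)
    assert cnt.frame_count == 2, f"expected 2 compiled frames, got {cnt.frame_count}"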

def test_float8_graph_input():
    # test that having a Float8Tensor object as a graph input works
    # works fine!

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
            self.register_buffer("fp8_scale_x", torch.tensor(1.0))

        def forward(self, x):
            x_fp8 = Float8Tensor.to_float8(
                x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x,
                emulate=False,
            )

            return x_fp8

    def to_float(x):
        return x.to_original_precision()

    to_float = torch.compile(to_float)

    m = M().cuda()
    x = torch.randn(2, 2, device='cuda')
    y = m(x)
    print(1, y)
    y2 = to_float(y)
    print(2, y2)
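
    # Illustrative follow-up (not in the original test): the compiled
    # to_float should agree with calling to_original_precision() eagerly
    # on the same Float8Tensor input.
    y_eager = y.to_original_precision()
    torch.testing.assert_close(y2, y_eager)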

def test_float8_graph_output():
    # test that having a Float8Tensor object as a graph output works
    # silently incorrect - `y` has fake tensors!

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
            self.register_buffer("fp8_scale_x", torch.tensor(1.0))

        def forward(self, x):
            x_fp8 = Float8Tensor.to_float8(
                x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x,
                emulate=False,
            )

            return x_fp8

    m = M().cuda()
    m = torch.compile(m)
    x = torch.randn(16, 16, device='cuda')
    y = m(x)
    print('y', y)
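
    # Illustrative follow-up (not in the original test): turn the silent
    # failure noted above into an explicit check instead of a print. This
    # assumes Float8Tensor keeps its payload in `_data` and `_scale`, and
    # that FakeTensor is importable from torch._subclasses; both details
    # may differ across versions.
    from torch._subclasses import FakeTensor

    for t in (y._data, y._scale):
        assert not isinstance(t, FakeTensor), "graph output leaked a FakeTensor"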



if __name__ == "__main__":
    pytest.main([__file__])