Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Add tests for Float8Tensor at graph boundaries #196

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import random
import unittest

Expand All @@ -11,6 +12,10 @@
import torch
import torch.nn as nn
from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
from float8_experimental.float8_tensor import Float8Tensor

from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import CompileCounterWithBackend

# Setting to unblock for calling contiguous in backwards
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
Expand Down Expand Up @@ -76,5 +81,87 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtyp
_test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)


class TestGraphBreaks(DynamoTestCase):
class MockLinear(torch.nn.Module):
def __init__(self, graph_break: bool):
super().__init__()
self.register_buffer("fp8_amax_x", torch.tensor(1.0))
self.register_buffer("fp8_scale_x", torch.tensor(1.0))
self.graph_break = graph_break

def forward(self, x):
x_fp8 = Float8Tensor.to_float8(
x,
self.fp8_scale_x,
torch.float8_e4m3fn,
self.fp8_amax_x,
emulate=True, # TODO: I set this to True so that people on A100 can test, but once fix is in, set to False
)
if self.graph_break:
torch._dynamo.graph_break()
x_hp = x_fp8.to_original_precision()
return x_hp
return x_fp8

@pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
def test_float8_with_graph_break_in_the_middle(self):
"""Test that having Float8Tensor object at the boundary of a subgraph"""
cnts = CompileCounterWithBackend("inductor")
mod = self.MockLinear(graph_break=True).cuda()
compiled_mod = copy.deepcopy(mod)
compiled_mod = torch.compile(compiled_mod, backend=cnts)
x = torch.randn(16, 16, device="cuda")
y_eager = mod(x)
y_compiled = compiled_mod(x)
self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
torch.testing.assert_close(y_eager, y_compiled)

def test_float8_graph_input(self):
"""Test that having Float8Tensor object as a graph input"""

def to_float(x):
return x.to_original_precision()

cnts = CompileCounterWithBackend("inductor")
mod = self.MockLinear(graph_break=False).cuda()
x = torch.randn(2, 2, device="cuda")
compiled_to_float = torch.compile(to_float, backend=cnts)
y = mod(x)
y2_eager = to_float(y)
y2_compiled = compiled_to_float(y)
self.assertEqual(
cnts.frame_count,
1,
"to_float was not compiled into 1 frame and likely encountered a skip!",
)
torch.testing.assert_close(y2_eager, y2_compiled)

@pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
def test_float8_graph_output(self):
"""Test that having Float8Tensor object as a graph output works"""
cnts = CompileCounterWithBackend("inductor")
mod = self.MockLinear(graph_break=False).cuda()
compiled_mod = torch.compile(mod, backend=cnts)
x = torch.randn(16, 16, device="cuda")
y_compiled = compiled_mod(x)

self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
tensors, ctx = y_compiled.__tensor_flatten__()
for tensor in tensors:
assert not isinstance(
getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor
), "Float8Tensor should not contain any FakeTensors!"
assert isinstance(
y_compiled._orig_dtype, torch.dtype
), "Float8Tensor._orig_dtype should be a dtype but got {}".format(
type(y_compiled._orig_dtype)
)
assert isinstance(
y_compiled._emulate, bool
), "Float8Tensor._emulate should be a bool but got {}".format(
type(y_compiled._emulate)
)


if __name__ == "__main__":
pytest.main([__file__])