
fix: Add test suite for torch.compile backend #1849

Merged

merged 1 commit into from Apr 26, 2023
17 changes: 17 additions & 0 deletions .circleci/config.yml
@@ -727,6 +727,22 @@ commands:
      - store_artifacts:
          path: /tmp/testlogs

  test-dynamo-torch_compile-core:
    description: "Test the Dynamo torch_compile core components"
    steps:
      - run:
          name: Run Dynamo torch_compile core tests
          command: |
            cd py/torch_tensorrt/dynamo/torch_compile
            pushd test/
            pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
            popd

      - store_test_results:
          path: /tmp/artifacts
      - store_artifacts:
          path: /tmp/testlogs

  test-dynamo-torch_compile:
    description: "Test the Dynamo torch_compile path"
    steps:
@@ -953,6 +969,7 @@ jobs:
      # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
      - dump-test-env
      - test-dynamo-torch_compile
      - test-dynamo-torch_compile-core
      - test-dynamo-fx_ts

  package-x86_64-linux:
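A quick local reproduction of the new CI step, via pytest's Python API rather than the shell commands above (a sketch, not part of the diff; the path assumes the repository layout shown in the config):

# Sketch: run the torch_compile test suite locally with the same JUnit output
# path the CI step uses. Assumes execution from the repository root.
import pytest

pytest.main(
    [
        "py/torch_tensorrt/dynamo/torch_compile/test",
        "--junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml",
    ]
)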
39 changes: 0 additions & 39 deletions py/torch_tensorrt/dynamo/test/utils.py
@@ -13,42 +13,3 @@ def cosine_similarity(gt_tensor, pred_tensor):
    res = res.cpu().detach().item()

    return res


def same_output_format(trt_output, torch_output):
    # For each encountered collection type, ensure the torch and trt outputs agree
    # on type and size, checking recursively through all member elements.
    if isinstance(trt_output, tuple):
        return (
            isinstance(torch_output, tuple)
            and (len(trt_output) == len(torch_output))
            and all(
                same_output_format(trt_entry, torch_entry)
                for trt_entry, torch_entry in zip(trt_output, torch_output)
            )
        )
    elif isinstance(trt_output, list):
        return (
            isinstance(torch_output, list)
            and (len(trt_output) == len(torch_output))
            and all(
                same_output_format(trt_entry, torch_entry)
                for trt_entry, torch_entry in zip(trt_output, torch_output)
            )
        )
    elif isinstance(trt_output, dict):
        return (
            isinstance(torch_output, dict)
            and (len(trt_output) == len(torch_output))
            and (trt_output.keys() == torch_output.keys())
            and all(
                same_output_format(trt_output[key], torch_output[key])
                for key in trt_output.keys()
            )
        )
    elif isinstance(trt_output, (set, frozenset)):
        raise AssertionError(
            "Unsupported output type 'set' encountered in output format check."
        )
    else:
        return type(trt_output) is type(torch_output)
57 changes: 57 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py
@@ -0,0 +1,57 @@
from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs
from utils import same_output_format
import torch_tensorrt
import unittest
import torch


class TestPrepareDevice(unittest.TestCase):
    def test_prepare_cuda_device(self):
        gpu_id = 0
        device = torch.device(f"cuda:{gpu_id}")
        prepared_device = prepare_device(device)
        self.assertTrue(isinstance(prepared_device, torch.device))
        self.assertTrue(prepared_device.index == gpu_id)

    def test_prepare_trt_device(self):
        gpu_id = 4
        device = torch_tensorrt.Device(gpu_id=gpu_id)
        prepared_device = prepare_device(device)
        self.assertTrue(isinstance(prepared_device, torch.device))
        self.assertTrue(prepared_device.index == gpu_id)


class TestPrepareInputs(unittest.TestCase):
    def test_prepare_single_tensor_input(self):
        inputs = [torch.ones((4, 4))]
        prepared_inputs = prepare_inputs(inputs)
        self.assertTrue(
            same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
        )

    def test_prepare_trt_input(self):
        inputs = [torch_tensorrt.Input(shape=(4, 3), dtype=torch.float)]
        prepared_inputs = prepare_inputs(inputs)
        self.assertTrue(
            same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
        )

    def test_prepare_mixed_type_compound_tensor_input(self):
        inputs = {
            "first": [
                torch.ones((4, 4)),
                torch_tensorrt.Input(shape=(4, 3), dtype=torch.float),
            ],
            "second": (
                torch.rand((5, 1)),
                (torch.rand((5, 1)), torch_tensorrt.Input(shape=(2, 3))),
            ),
        }
        prepared_inputs = prepare_inputs(inputs)
        self.assertTrue(
            same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
        )


if __name__ == "__main__":
    unittest.main()
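The mixed-type test above pins down the implied contract of prepare_inputs: container structure is preserved while leaf types may change (e.g., a torch_tensorrt.Input becoming a Tensor), which is why the comparison runs with enforce_tensor_type=False. A minimal sketch of that contract, inferred from the tests rather than taken from the PR:

# Sketch (inferred behavior, not authoritative): prepare_inputs mirrors the
# container layout of its argument, so only the structure is asserted here.
import torch
from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs

nested = {"first": [torch.ones((4, 4))], "second": (torch.rand((5, 1)),)}
prepared = prepare_inputs(nested)
assert isinstance(prepared, dict) and set(prepared) == {"first", "second"}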
54 changes: 54 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py
@@ -0,0 +1,54 @@
from functools import partial
from utils import fx_dynamo_testing_backend
from torch.testing._internal.common_utils import run_tests, TestCase
import torch


class TestTRTModule(TestCase):
    def test_lowering_inplace_op(self):
        class FullySupported(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                x = torch.ops.aten.add_.Tensor(x, y)
                x = torch.ops.aten.relu_.default(x)
                return x

        # Operations expected to be included in the traced graph after decompositions
        expected_ops = {torch.ops.aten.add.Tensor, torch.ops.aten.relu.default}

        # Trace module and set up custom backend to track intermediate graphs
        fx_graph = torch.fx.symbolic_trace(FullySupported())
        partitioned_graphs = []
        custom_backend = partial(
            fx_dynamo_testing_backend,
            store_intermediate_graphs=partitioned_graphs,
        )

        # Invoke compilation
        compiled_graph = torch.compile(fx_graph, backend=custom_backend)
        compiled_graph(torch.rand(5).cuda(), torch.rand(5).cuda())

        # Iterate over intermediate graphs, attempt to match nodes
        for fx_module in partitioned_graphs:
            for _, submodule in fx_module.named_children():
                for node in submodule.graph.nodes:
                    if node.op == "call_function" and node.target in expected_ops:
                        expected_ops.remove(node.target)

        self.assertEqual(
            len(expected_ops), 0, "All operators should have been decomposed"
        )


if __name__ == "__main__":
    run_tests()
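The rewrite this test asserts — in-place aten ops replaced by their functional counterparts — can also be observed with stock PyTorch functionalization, independent of the Torch-TensorRT backend. A minimal sketch, assuming PyTorch 2.x:

# Sketch: tracing through functionalize records aten.add.Tensor and
# aten.relu.default in place of the in-place variants, mirroring expected_ops.
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def inplace_ops(x, y):
    x = torch.ops.aten.add_.Tensor(x, y)
    return torch.ops.aten.relu_.default(x)

gm = make_fx(functionalize(inplace_ops))(torch.ones(5), torch.ones(5))
print(gm.graph)  # shows add/relu, not add_/relu_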
68 changes: 68 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py
@@ -0,0 +1,68 @@
from torch_tensorrt.dynamo.torch_compile.lowering import partition
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
from copy import deepcopy
import numpy as np


class TestPartitioning(TestCase):
    def test_partition_fully_supported_one_op(self):
        class FullySupportedOneOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                return torch.ops.aten.add.Tensor(x, y)

        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
        partitioned_graph = partition(deepcopy(fx_graph))
        self.assertEqual(
            len(list(partitioned_graph.named_children())),
            0,
            "Single operators should not be segmented",
        )

    def test_partition_fully_supported_multi_op(self):
        class FullySupportedMultiOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                sum_ = torch.ops.aten.sub.Tensor(x, y)
                # aten.cat expects a list of tensors as its first argument
                concat_ = torch.ops.aten.cat.default([x, sum_])
                relu_ = torch.ops.aten.relu.default(concat_)
                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
                return pow_

        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
        partitioned_graph = partition(deepcopy(fx_graph))
        self.assertEqual(
            len(list(partitioned_graph.named_children())),
            1,
            "All operators are supported, so there should be exactly one segment",
        )

    def test_partition_partially_supported_multi_op(self):
        class PartiallySupportedMultiOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                sum_1 = torch.ops.aten.add.Tensor(x, y)
                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
                # numpy calls are unsupported operators, forcing a graph break
                sum_ = np.sum(sum_1) + np.sum(sum_2)
                relu_ = torch.ops.aten.relu.default(sum_)
                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
                return pow_

        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
        partitioned_graph = partition(deepcopy(fx_graph))
        self.assertEqual(
            len(list(partitioned_graph.named_children())),
            2,
            "Unsupported operators interleave the supported ones; expected 2 segments",
        )


if __name__ == "__main__":
    run_tests()
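For debugging these segment counts, the partitioned module's children can be inspected directly. A hypothetical helper (print_segments is not part of the PR) built on the same partition output the tests count:

# Sketch: list each accelerated segment and the aten call targets it contains.
import torch

def print_segments(partitioned_graph: torch.fx.GraphModule) -> None:
    for name, submodule in partitioned_graph.named_children():
        targets = [
            str(node.target)
            for node in submodule.graph.nodes
            if node.op == "call_function"
        ]
        print(f"{name}: {targets}")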
94 changes: 94 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/test/utils.py
@@ -0,0 +1,94 @@
from copy import deepcopy
from functools import partial
from typing import List, Sequence
import torch
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
    get_decompositions,
)
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
    partition,
)

from torch._dynamo.backends.common import fake_tensor_unsupported

from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler


@fake_tensor_unsupported
def fx_dynamo_testing_backend(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    *,
    store_intermediate_graphs: List,
):
    """Helper Dynamo backend exclusively for testing"""
    custom_backend = partial(
        compile_module_testing,
        store_intermediate_graphs=store_intermediate_graphs,
    )

    # Invoke AOTAutograd to translate operators to aten
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=make_boxed_compiler(custom_backend),
        decompositions=get_decompositions(),
    )


def compile_module_testing(
    gm: torch.fx.GraphModule,
    example_inputs: Sequence[torch.Tensor],
    *,
    store_intermediate_graphs: List,
) -> torch.fx.GraphModule:
    """Helper compiler exclusively for testing"""
    partitioned_module = partition(gm)

    # Store intermediate graph from partitioned module
    store_intermediate_graphs.append(deepcopy(partitioned_module))

    return partitioned_module


def same_output_format(trt_output, torch_output, enforce_tensor_type=True):
    # For each encountered collection type, ensure the torch and trt outputs agree
    # on type and size, checking recursively through all member elements.
    if isinstance(trt_output, tuple):
        return (
            isinstance(torch_output, tuple)
            and (len(trt_output) == len(torch_output))
            and all(
                same_output_format(trt_entry, torch_entry, enforce_tensor_type)
                for trt_entry, torch_entry in zip(trt_output, torch_output)
            )
        )
    elif isinstance(trt_output, list):
        return (
            isinstance(torch_output, list)
            and (len(trt_output) == len(torch_output))
            and all(
                same_output_format(trt_entry, torch_entry, enforce_tensor_type)
                for trt_entry, torch_entry in zip(trt_output, torch_output)
            )
        )
    elif isinstance(trt_output, dict):
        return (
            isinstance(torch_output, dict)
            and (len(trt_output) == len(torch_output))
            and (trt_output.keys() == torch_output.keys())
            and all(
                same_output_format(
                    trt_output[key], torch_output[key], enforce_tensor_type
                )
                for key in trt_output.keys()
            )
        )
    elif isinstance(trt_output, (set, frozenset)):
        raise AssertionError(
            "Unsupported output type 'set' encountered in output format check."
        )
    elif enforce_tensor_type:
        return type(trt_output) is type(torch_output)
    else:
        return True
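Usage sketch for same_output_format (not part of the diff, but follows directly from the definition above): the check compares container structure only, never tensor values, and a tuple/list mismatch fails even when lengths agree:

import torch

# Matching structure (dict -> tuple -> list) passes regardless of values.
a = {"out": (torch.ones(2), [torch.zeros(3)])}
b = {"out": (torch.rand(2), [torch.rand(3)])}
assert same_output_format(a, b)

# A tuple on one side and a list on the other fails the check.
assert not same_output_format(a, {"out": [torch.rand(2), [torch.rand(3)]]})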
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/utils.py
@@ -64,3 +64,5 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
        raise ValueError(
            "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
        )

    return device
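This added return device appears to be the fix the PR title alludes to: previously the valid-device branches fell off the end of the function and implicitly returned None. A minimal regression check (a sketch; requires a CUDA-enabled build):

import torch
from torch_tensorrt.dynamo.torch_compile.utils import prepare_device

# With the fix, a valid torch.device input round-trips instead of yielding None.
assert prepare_device(torch.device("cuda:0")) is not None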