chore: Switch converter tests to generate standalone ops using fx.symbolic_trace #2361

Merged · 13 commits · Oct 5, 2023
105 changes: 23 additions & 82 deletions tests/py/dynamo/conversion/harness.py
@@ -4,8 +4,6 @@
from typing import Callable, List, Optional, Set, Tuple

import torch
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
from torch.fx.passes.infra.pass_base import PassResult
from torch.testing._internal.common_utils import TestCase
from torch_tensorrt import Input
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -14,19 +12,6 @@
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
compose_bmm,
compose_chunk,
compose_getitem_slice,
remove_ops,
replace_aten_op_with_indices,
replace_aten_reshape_alias_with_replace,
replace_builtin_ops,
replace_native_layernorm_with_layernorm,
replace_transpose_mm_op_with_linear,
run_const_fold,
)
from torch_tensorrt.fx.passes.pass_utils import chain_passes

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -62,8 +47,6 @@ def run_test(
self,
mod,
inputs,
expected_ops,
unexpected_ops,
interpreter,
rtol,
atol,
@@ -76,10 +59,6 @@
cuda_inputs.append(i.cuda())

mod.eval()
if len(expected_ops):
self.assert_has_op(mod, expected_ops)
if unexpected_ops:
self.assert_unexpected_op(mod, unexpected_ops)
start = time.perf_counter()
interpreter_result = interpreter.run(precision=precision)
sec = time.perf_counter() - start
@@ -215,75 +194,44 @@ def generate_graph(
self,
mod: torch.nn.Module,
original_inputs: List[torch.Tensor],
expected_ops: Set[Callable],
unexpected_ops: Optional[Set[Callable]] = None,
customized_passes: List[Callable] = None,
disable_passes: bool = False,
use_dynamo_tracer: bool,
enable_passes: bool,
):
# Torchdynamo+aot proxytensor tracer
# Below are common passes
passes_list = [
compose_bmm,
compose_chunk,
compose_getitem_slice,
replace_aten_reshape_alias_with_replace,
replace_aten_op_with_indices,
replace_transpose_mm_op_with_linear, # after compose_bmm
replace_native_layernorm_with_layernorm,
remove_ops,
replace_builtin_ops, # after replace_native_layernorm_with_layernorm
]
# Combine with customized passes specific to any model
if customized_passes:
passes_list.extend(customized_passes)

if disable_passes:
passes_list = []

fx_module, _ = aten_tracer.trace(mod, original_inputs)
for passes in passes_list:
pr: PassResult = passes(fx_module)
fx_module = pr.graph_module
fx_module(*original_inputs)

fx_module = run_const_fold(fx_module)
if use_dynamo_tracer:
fx_module = torch._dynamo.export(
mod,
*original_inputs,
aten_graph=True,
assume_static_by_default=True,
tracing_mode="real",
).graph_module
else:
fx_module = torch.fx.symbolic_trace(mod)
if enable_passes:
fx_module = apply_lowering_passes(fx_module, original_inputs)
_LOGGER.info(f"FX graph= {fx_module.graph}")

if len(expected_ops):
self.assert_has_op(fx_module, expected_ops)
if unexpected_ops:
self.assert_unexpected_op(fx_module, unexpected_ops)

return fx_module

def run_test(
self,
mod,
inputs,
expected_ops,
unexpected_ops=None,
apply_passes=None,
rtol=1e-03,
atol=1e-03,
precision=torch.float,
check_dtype=True,
disable_passes=False,
output_dtypes=None,
use_dynamo_tracer=False,
enable_passes=False,
):
mod.eval()
mod = self.generate_graph(
mod,
inputs,
expected_ops,
unexpected_ops,
None,
disable_passes=disable_passes,
use_dynamo_tracer=use_dynamo_tracer,
enable_passes=enable_passes,
)

if apply_passes is not None:
pass_tracer = chain_passes(*apply_passes)
mod = pass_tracer(mod, inputs)

# Previous instance of the interpreter auto-casted 64-bit inputs
# We replicate this behavior here
compilation_settings = CompilationSettings(truncate_long_and_double=True)
@@ -297,8 +245,6 @@
super().run_test(
mod,
inputs,
expected_ops,
unexpected_ops,
interp,
rtol,
atol,
@@ -310,22 +256,19 @@ def run_test_with_dynamic_shape(
self,
mod,
input_specs,
expected_ops,
unexpected_ops=None,
rtol=1e-03,
atol=1e-03,
disable_passes=False,
output_dtypes=None,
use_dynamo_tracer=False,
enable_passes=False,
):
mod.eval()
inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
mod = self.generate_graph(
mod,
inputs,
expected_ops,
unexpected_ops,
None,
disable_passes=disable_passes,
use_dynamo_tracer=use_dynamo_tracer,
enable_passes=enable_passes,
)

# Previous instance of the interpreter auto-casted 64-bit inputs
@@ -341,6 +284,4 @@ def run_test_with_dynamic_shape(
# Since the lowering is based on the optimal shape, we need to test with a
# different shape (e.g. the max shape) to exercise dynamic shapes
inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
super().run_test(
mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol
)
super().run_test(mod, inputs_max, interp, rtol, atol)
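
For context, a minimal sketch (not part of the diff) of the two tracing paths the reworked generate_graph chooses between: plain torch.fx.symbolic_trace for modules that call standalone ATen ops, and torch._dynamo.export when use_dynamo_tracer=True. The ToyModule name, shapes, and inputs are hypothetical; the export call simply mirrors the one in the harness above.

```python
import torch
import torch.nn as nn


class ToyModule(nn.Module):  # hypothetical module for illustration
    def forward(self, x):
        return torch.ops.aten.abs.default(x)


mod = ToyModule().eval()
example_inputs = [torch.randn(2, 3)]

# Path 1 (default): plain FX symbolic trace records the aten call as written,
# yielding a standalone-op graph with no decomposition or lowering.
fx_module = torch.fx.symbolic_trace(mod)
print(fx_module.graph)

# Path 2 (use_dynamo_tracer=True): export through torch._dynamo to get an
# ATen-level graph, mirroring the call in generate_graph above.
fx_module = torch._dynamo.export(
    mod,
    *example_inputs,
    aten_graph=True,
    assume_static_by_default=True,
    tracing_mode="real",
).graph_module
print(fx_module.graph)
```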
6 changes: 2 additions & 4 deletions tests/py/dynamo/conversion/test_abs_aten.py
@@ -18,13 +18,12 @@ class TestAbsConverter(DispatchTestCase):
def test_abs_float(self, input_shape, dtype):
class abs(nn.Module):
def forward(self, input):
return torch.abs(input)
return torch.ops.aten.abs.default(input)

inputs = [torch.randn(input_shape, dtype=dtype)]
self.run_test(
abs(),
inputs,
expected_ops={torch.ops.aten.abs.default},
)

@parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
def test_abs_int(self, input_shape, dtype, low, high):
class abs(nn.Module):
def forward(self, input):
return torch.abs(input)
return torch.ops.aten.abs.default(input)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
abs(),
inputs,
expected_ops={torch.ops.aten.abs.default},
output_dtypes=[torch.int],
)
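
A short sketch (not part of the PR) of why these tests now call the ATen overload directly: torch.fx.symbolic_trace records exactly the call written in forward(), so torch.ops.aten.abs.default appears as a standalone call_function node for the converter under test, whereas torch.abs would be traced as the Python-level function. The module names here are hypothetical.

```python
import torch
import torch.nn as nn


class AbsAten(nn.Module):  # hypothetical name for illustration
    def forward(self, x):
        return torch.ops.aten.abs.default(x)


class AbsPython(nn.Module):  # hypothetical name for illustration
    def forward(self, x):
        return torch.abs(x)


# The ATen overload is preserved as a standalone call_function node ...
print(torch.fx.symbolic_trace(AbsAten()).graph)

# ... while torch.abs is recorded as the Python-level torch.abs call.
print(torch.fx.symbolic_trace(AbsPython()).graph)
```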

6 changes: 2 additions & 4 deletions tests/py/dynamo/conversion/test_acos_aten.py
@@ -18,13 +18,12 @@ class TestAcosConverter(DispatchTestCase):
def test_acos_float(self, input_shape, dtype):
class acos(nn.Module):
def forward(self, input):
return torch.acos(input)
return torch.ops.aten.acos.default(input)

inputs = [torch.randn(input_shape, dtype=dtype)]
self.run_test(
acos(),
inputs,
expected_ops={torch.ops.aten.acos.default},
)

@parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
def test_acos_int(self, input_shape, dtype, low, high):
class acos(nn.Module):
def forward(self, input):
return torch.acos(input)
return torch.ops.aten.acos.default(input)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
acos(),
inputs,
expected_ops={torch.ops.aten.acos.default},
)


6 changes: 2 additions & 4 deletions tests/py/dynamo/conversion/test_acosh_aten.py
@@ -18,13 +18,12 @@ class TestAcoshConverter(DispatchTestCase):
def test_acosh_float(self, input_shape, dtype):
class acosh(nn.Module):
def forward(self, input):
return torch.acosh(input)
return torch.ops.aten.acosh.default(input)

inputs = [torch.randn(input_shape, dtype=dtype)]
self.run_test(
acosh(),
inputs,
expected_ops={torch.ops.aten.acosh.default},
)

@parameterized.expand(
Expand All @@ -37,13 +36,12 @@ def forward(self, input):
def test_acosh_int(self, input_shape, dtype, low, high):
class acosh(nn.Module):
def forward(self, input):
return torch.acosh(input)
return torch.ops.aten.acosh.default(input)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
acosh(),
inputs,
expected_ops={torch.ops.aten.acosh.default},
)


28 changes: 5 additions & 23 deletions tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py
@@ -7,27 +7,11 @@


class TestAdaptiveAvgPoolConverter(DispatchTestCase):
def test_adaptive_avgpool_mean(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

def forward(self, x):
return self.pool(x)

inputs = [torch.randn(1, 3, 256, 256)]
self.run_test(
TestModule(),
inputs,
expected_ops={torch.ops.aten.mean.dim},
)

@parameterized.expand(
[
((64, 64),),
((128, 64),),
(64,),
# (64,), This case was present in the previous code but is not valid PyTorch code.
]
)
def test_adaptive_avgpool(
@@ -46,7 +30,7 @@ def forward(self, x):
self.run_test(
TestModule(),
inputs,
expected_ops={torch.ops.aten._adaptive_avg_pool2d.default},
use_dynamo_tracer=True,
)

def test_adaptive_avgpool_with_dynamic_shape(self):
@@ -66,9 +50,7 @@ def forward(self, x):
),
]
self.run_test_with_dynamic_shape(
TestModule(),
input_specs,
expected_ops={torch.ops.aten._adaptive_avg_pool2d.default},
TestModule(), input_specs, use_dynamo_tracer=True
)

@parameterized.expand(
@@ -94,7 +76,7 @@ def forward(self, x):
self.run_test(
TestModule(),
inputs,
expected_ops={torch.ops.aten._adaptive_avg_pool3d.default},
use_dynamo_tracer=True,
)

def test_adaptive_avgpool3d_with_dynamic_shape(self):
@@ -118,7 +100,7 @@ def forward(self, x):
self.run_test_with_dynamic_shape(
TestModule(),
input_specs,
expected_ops={torch.ops.aten._adaptive_avg_pool3d.default},
use_dynamo_tracer=True,
)

# Testing with shape (-1, -1, -1, -1) results in the error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently don't support dynamic shapes for the last two dims."
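
A minimal sketch (not part of the PR) of why these pooling tests opt into use_dynamo_tracer=True: under plain symbolic_trace the nn.AdaptiveAvgPool2d submodule is kept as a call_module node, with no ATen op for the converter to match, while the dynamo export path used by the harness lowers it to ATen-level pooling ops. The Pool name and shapes are illustrative, and the export call mirrors the harness.

```python
import torch
import torch.nn as nn


class Pool(nn.Module):  # illustrative module, mirroring TestModule above
    def __init__(self):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((64, 64))

    def forward(self, x):
        return self.pool(x)


inputs = [torch.randn(1, 3, 256, 256)]

# Plain FX trace: the pool stays a call_module node.
print(torch.fx.symbolic_trace(Pool()).graph)

# Dynamo export (the use_dynamo_tracer=True path in the harness): the graph
# contains ATen-level pooling ops instead.
gm = torch._dynamo.export(
    Pool(),
    *inputs,
    aten_graph=True,
    assume_static_by_default=True,
    tracing_mode="real",
).graph_module
print(gm.graph)
```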