Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

skip dummy inference and run_shape_analysis #3212

Merged
merged 30 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
458a4d1
skip run_shape_analysis
lanluo-nvidia Oct 6, 2024
2f408f9
test
lanluo-nvidia Oct 6, 2024
1c5e86c
test
lanluo-nvidia Oct 6, 2024
ba487dc
test
lanluo-nvidia Oct 6, 2024
99d2274
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 6, 2024
2b43480
test
lanluo-nvidia Oct 6, 2024
b4e02e1
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 11, 2024
3d94f8b
test
lanluo-nvidia Oct 13, 2024
28ba6cc
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 15, 2024
b89cbe0
resolve comments
lanluo-nvidia Oct 15, 2024
2843d37
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 16, 2024
3eb48d7
test
lanluo-nvidia Oct 16, 2024
50eb0d8
replace dummy inference
lanluo-nvidia Oct 20, 2024
95ed602
test
lanluo-nvidia Oct 20, 2024
120f30d
test
lanluo-nvidia Oct 21, 2024
424cbf7
add run_test_with_dynamic_shape change
lanluo-nvidia Oct 21, 2024
2fc9cef
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 21, 2024
ef54cfc
split the PR, add dummy inference for converter test
lanluo-nvidia Oct 21, 2024
14f5d61
test
lanluo-nvidia Oct 22, 2024
7563959
test
lanluo-nvidia Oct 22, 2024
77355f0
test
lanluo-nvidia Oct 22, 2024
13361fd
add linear lowering meta val
lanluo-nvidia Oct 22, 2024
f0a9fef
add linear_lowering change
lanluo-nvidia Oct 23, 2024
cff64a4
test
lanluo-nvidia Oct 23, 2024
933abac
test
lanluo-nvidia Oct 23, 2024
8417684
resolve comments
lanluo-nvidia Oct 25, 2024
8676f88
test
lanluo-nvidia Oct 25, 2024
076f47a
resolve comments
lanluo-nvidia Oct 29, 2024
8250179
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 29, 2024
96e93e4
resolve comments
lanluo-nvidia Oct 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
split the PR, add dummy inference for converter test
  • Loading branch information
lanluo-nvidia committed Oct 21, 2024
commit ef54cfce4f63fad4e5ecfb0da827b5e25302550c
52 changes: 52 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import tensorrt as trt
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
Expand All @@ -16,6 +18,7 @@
TRTInterpreterResult,
)
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,6 +60,55 @@ def infer_module_output_dtypes(
return get_output_dtypes(outputs, truncate_double)


def infer_module_output_dtypes_for_test(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
    device: Device,
    kwarg_inputs: Optional[dict[str, Any]] = None,
    truncate_double: bool = False,
) -> List[dtype]:
    """
    This function performs model inference to determine the output dtypes
    and truncates them accordingly. inputs can be either arg_inputs or flattened input list.
    If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input.

    Args:
        module: the traced GraphModule to run inference on.
        device: target device. NOTE(review): currently unused — the device is always
            re-derived from the module's own parameters below; kept for API stability.
        kwarg_inputs: optional keyword inputs; ``None`` is treated as empty.
        truncate_double: if True, report float64 outputs as float32.

    Returns:
        One ``dtype`` per module output, in output order.

    Raises:
        ValueError: if the module produces a string output (not convertible to a tensor).
    """
    # TODO: We can also determine output dtypes from the module.graph based on node metadata.
    # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata,
    # so we stick to the model inference approach currently.
    with unset_fake_temporarily():
        # Get the device on which the model exists.
        # For large models, this can be done on CPU to save GPU memory allocation for TRT.
        # Use a distinct local name instead of shadowing the `device` parameter,
        # which the original code silently overwrote.
        model_device = get_model_device(module)
        torch_inputs = get_torch_inputs(inputs, model_device)
        if kwarg_inputs is None:
            kwarg_inputs = {}
        torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, model_device)
        module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
        if not isinstance(module_outputs, (list, tuple)):
            module_outputs = [module_outputs]

        # Int64 outputs can sometimes be generated from within other operators
        # such as aten.sum - such outputs can be truncated
        output_dtypes = []
        for output in module_outputs:
            output_ = output
            # We don't need to check if output is nested here because the input module will be flattened
            if not isinstance(output, torch.Tensor):
                if isinstance(output, str):
                    raise ValueError(
                        f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
                    )
                else:
                    # Scalars (int/float/bool) are wrapped so their dtype can be read uniformly.
                    output_ = torch.tensor(output)

            if truncate_double and output_.dtype == dtype.float64:
                output_dtypes.append(dtype.float32)
            else:
                output_dtypes.append(dtype._from(output_.dtype))

    return output_dtypes


def interpret_module_to_result(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
Expand Down
16 changes: 11 additions & 5 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
from torch_tensorrt.dynamo.conversion._conversion import (
infer_module_output_dtypes_for_test,
)
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
post_lowering,
Expand Down Expand Up @@ -273,7 +275,7 @@ def run_test(
atol=ATOL,
precision=dtype.f32,
check_dtype=True,
use_dynamo_tracer=True,
use_dynamo_tracer=False,
enable_passes=False,
propagate_shapes=False,
int32_reqd=False,
Expand Down Expand Up @@ -326,8 +328,10 @@ def run_test(

output_dtypes = None
if check_dtype:
output_dtypes = infer_module_output_dtypes(
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)

Expand Down Expand Up @@ -399,7 +403,7 @@ def run_test_with_dynamic_shape(
rtol=RTOL,
atol=ATOL,
output_dtypes=None,
use_dynamo_tracer=True,
use_dynamo_tracer=False,
enable_passes=False,
use_example_tensors=True,
pyt_inputs=None,
Expand All @@ -426,8 +430,10 @@ def run_test_with_dynamic_shape(
)

if check_dtype:
output_dtypes = infer_module_output_dtypes(
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/py/dynamo/conversion/test_acos_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(2, 2, 3),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
Expand Down
4 changes: 2 additions & 2 deletions tests/py/dynamo/conversion/test_acosh_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(2, 2, 3),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
(
"3d_dim_dtype_int32",
(1, 1, 1),
(2, 2, 4),
(1, 2, 4),
(2, 3, 5),
torch.int32,
torch.float,
Expand Down
8 changes: 4 additions & 4 deletions tests/py/dynamo/conversion/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class TestAnyConverterDynamic(DispatchTestCase):
(
"3d_dynamic_float",
(2, 1, 1),
(2, 2, 2),
(2, 2, 1),
(3, 2, 4),
torch.float,
),
Expand Down Expand Up @@ -234,7 +234,7 @@ def forward(self, x):
(
"3d_dynamic_dim_float",
(2, 1, 1),
(2, 2, 2),
(2, 2, 1),
(3, 2, 4),
torch.float,
2,
Expand All @@ -252,7 +252,7 @@ def forward(self, x):
(
"3d_dynamic_dim_bool",
(2, 1, 1),
(2, 2, 2),
(2, 2, 1),
(3, 2, 4),
torch.bool,
0,
Expand Down Expand Up @@ -285,7 +285,7 @@ def forward(self, x):
(
"3d_dynamic_dims_float",
(2, 1, 1),
(2, 2, 2),
(2, 2, 1),
(3, 2, 4),
torch.float,
[1, 2],
Expand Down
1 change: 0 additions & 1 deletion tests/py/dynamo/conversion/test_arange_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def forward(self, end_tensor):
use_example_tensors=False,
check_dtype=False,
pyt_inputs=[pyt_input],
use_dynamo_tracer=False,
)


Expand Down
2 changes: 0 additions & 2 deletions tests/py/dynamo/conversion/test_cat_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,12 @@ def forward(self, x, y):
min_shape=(16, 2, 3),
opt_shape=(16, 3, 3),
max_shape=(16, 32, 3),
name="x",
),
Input(
dtype=torch.float32,
min_shape=(16, 2, 3),
opt_shape=(16, 16, 3),
max_shape=(16, 32, 3),
name="y",
),
]
self.run_test_with_dynamic_shape(
Expand Down
6 changes: 1 addition & 5 deletions tests/py/dynamo/conversion/test_full_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ def forward(self, shape):
)
]
self.run_test_with_dynamic_shape(
full(),
inputs,
use_example_tensors=False,
check_dtype=False,
use_dynamo_tracer=False,
full(), inputs, use_example_tensors=False, check_dtype=False
)

@parameterized.expand(
Expand Down
33 changes: 1 addition & 32 deletions tests/py/dynamo/conversion/test_ge_aten.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.export import Dim
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

Expand Down Expand Up @@ -84,40 +83,10 @@ def forward(self, lhs_val, rhs_val):
inputs,
)

@parameterized.expand(
[
("3d_2d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (2, 1), (3, 1), (4, 1)),
]
)
def test_ge_dynamic_tensor_torch_export(self, *args):
class ge(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.ge.Tensor(lhs_val, rhs_val)

input_specs = [
Input(
min_shape=args[1],
opt_shape=args[2],
max_shape=args[3],
),
Input(
min_shape=args[4],
opt_shape=args[5],
max_shape=args[6],
),
]
dyn_dim = Dim("dyn_dim", min=2, max=4)
torch_export_dynamic_shapes = {"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}}

self.run_test_with_dynamic_shape(
ge(),
input_specs,
torch_export_dynamic_shapes=torch_export_dynamic_shapes,
)

@parameterized.expand(
[
("2d_2d", (2, 3), (4, 3), (5, 3), (2, 3), (4, 3), (5, 3)),
("3d_2d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (2, 1), (3, 1), (4, 1)),
]
)
def test_ge_dynamic_tensor(self, *args):
Expand Down
11 changes: 2 additions & 9 deletions tests/py/dynamo/conversion/test_isinf_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,7 @@ def forward(self, input):
max_shape=(5, 3, 3),
dtype=torch.float32,
torch_tensor=torch.tensor(
(
[
[[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]],
[[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]],
]
),
([[[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]]]),
dtype=torch.float32,
).cuda(),
)
Expand All @@ -77,9 +72,7 @@ def forward(self, input):
opt_shape=(3, 2),
max_shape=(5, 3),
dtype=torch.int,
torch_tensor=torch.tensor(
([[-3, 2], [-2, 1], [1, 2]]), dtype=torch.int
).cuda(),
torch_tensor=torch.tensor(([[-3, 2]]), dtype=torch.int).cuda(),
)
]
self.run_test_with_dynamic_shape(
Expand Down
13 changes: 1 addition & 12 deletions tests/py/dynamo/conversion/test_isnan_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,7 @@ def forward(self, input):
max_shape=(5, 3, 3),
dtype=torch.float32,
torch_tensor=torch.tensor(
(
[
[
[3.2, float("nan"), 3.1],
[float("inf"), 1.1, float("nan")],
],
[
[3.2, float("nan"), 3.1],
[float("inf"), 1.1, float("nan")],
],
]
),
([[[3.2, float("nan"), 3.1], [float("inf"), 1.1, float("nan")]]]),
dtype=torch.float32,
).cuda(),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/py/dynamo/conversion/test_logical_and_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def forward(self, lhs_val, rhs_val):
(
"3d_dim_dtype_bool",
(1, 1, 1),
(2, 2, 3),
(1, 2, 3),
(3, 3, 3),
torch.bool,
),
Expand Down
Loading