Skip to content

Commit 66ea963

Browse files
committed
dynamic test case for full_like_to_full
1 parent f61d10c commit 66ea963

File tree

1 file changed

+64
-2
lines changed

1 file changed

+64
-2
lines changed

tests/py/dynamo/lowering/test_decompositions.py

+64-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import torch
22
import torch_tensorrt
33
from parameterized import parameterized
4+
from testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
45
from torch.testing._internal.common_utils import TestCase, run_tests
56

6-
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
7-
87

98
class TestLowering(TestCase):
109
def test_lowering_inplace_op(self):
@@ -483,6 +482,69 @@ def forward(self, x):
483482
f"FullLike TRT outputs don't match with the original model.",
484483
)
485484

485+
def test_lowering_full_like_to_full_dynamic_module(self):
486+
class FullLike(torch.nn.Module):
487+
def __init__(self, *args, **kwargs) -> None:
488+
super().__init__(*args, **kwargs)
489+
490+
def forward(self, x):
491+
c = torch.ops.aten.add(x, x)
492+
y = torch.ops.aten.full_like.default(c, 2)
493+
d = y + c
494+
return d
495+
496+
# Operations expected to be removed in the traced graph after decompositions
497+
expected_ops = {torch.ops.aten.add.Tensor}
498+
unexpected_ops = {torch.ops.aten.full_like.default}
499+
500+
inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]
501+
torch._dynamo.mark_dynamic(inputs[0], 0, min=1, max=3)
502+
fx_graph = torch.fx.symbolic_trace(FullLike())
503+
504+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
505+
fx_graph,
506+
inputs,
507+
expected_ops=expected_ops,
508+
unexpected_ops=unexpected_ops,
509+
min_block_size=1,
510+
)
511+
512+
self.assertEqual(
513+
len(unexpected_ops_seen),
514+
0,
515+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
516+
)
517+
518+
self.assertEqual(
519+
len(expected_ops_unseen),
520+
0,
521+
f"The following expected ops were not encountered: {expected_ops_unseen}",
522+
)
523+
524+
torch._dynamo.reset()
525+
526+
# Validate that the results between Torch and Torch-TRT are similar
527+
optimized_model = torch_tensorrt.compile(
528+
fx_graph,
529+
"torch_compile",
530+
inputs,
531+
min_block_size=1,
532+
truncate_double=True,
533+
pass_through_build_failures=True,
534+
)
535+
optimized_model_results = optimized_model(*inputs).detach().cpu()
536+
torch_model_results = fx_graph(*inputs).detach().cpu()
537+
538+
max_diff = float(
539+
torch.max(torch.abs(optimized_model_results - torch_model_results))
540+
)
541+
self.assertAlmostEqual(
542+
max_diff,
543+
0,
544+
DECIMALS_OF_AGREEMENT,
545+
f"FullLike TRT outputs don't match with the original model.",
546+
)
547+
486548
def test_lowering_empty_like_module(self):
487549
class emptyLike(torch.nn.Module):
488550
def __init__(self, *args, **kwargs) -> None:

0 commit comments

Comments
 (0)