Skip to content

Commit 0fa5c13

Browse files
committed
fix: Remove input aliasing with builtin ops
- Add replacements for inplace builtin operators with their out-of-place equivalents
- Add utility to automatically perform replacement prior to AOT tracing
- Add test cases to verify inplace operators are replaced accurately
1 parent 07b492b commit 0fa5c13

File tree

5 files changed

+149
-3
lines changed

5 files changed

+149
-3
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from torch._ops import OpOverload
1414
from torch_tensorrt.dynamo import CompilationSettings
1515
from torch_tensorrt.dynamo.compile import compile_module
16-
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
16+
from torch_tensorrt.dynamo.lowering import (
17+
get_decompositions,
18+
replace_builtin_inplace_ops,
19+
)
1720
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
1821
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
1922

@@ -74,6 +77,8 @@ def _pretraced_backend(
7477
with unittest.mock.patch.object(
7578
fake_mode, "allow_non_fake_inputs", True
7679
), fake_mode:
80+
replace_builtin_inplace_ops(gm)
81+
7782
# Invoke AOTAutograd to translate operators to aten
7883
graph_module = aot_export_for_compile(
7984
gm,

py/torch_tensorrt/dynamo/lowering/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
from ._fusers import * # noqa: F401
33
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
44
from ._pre_aot_lowering import register_substitution # noqa: F401
5+
from ._replace_inplace_ops import replace_builtin_inplace_ops
56
from .substitutions import * # noqa: F401
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import logging
import operator

import torch

logger = logging.getLogger(__name__)

# Maps each in-place builtin from Python's `operator` module to the
# out-of-place builtin that computes the same result without writing
# into its first operand.
BUILTIN_TRANSLATION = {
    operator.ipow: operator.pow,
    operator.imul: operator.mul,
    operator.imatmul: operator.matmul,
    operator.ifloordiv: operator.floordiv,
    operator.itruediv: operator.truediv,
    operator.imod: operator.mod,
    operator.iadd: operator.add,
    operator.isub: operator.sub,
    operator.ilshift: operator.lshift,
    operator.irshift: operator.rshift,
    operator.iand: operator.and_,
    operator.ixor: operator.xor,
    operator.ior: operator.or_,
}


def replace_builtin_inplace_ops(gm: torch.fx.GraphModule) -> None:
    """Replace in-place builtins from Python's ``operator`` module.

    Every ``call_function`` node whose target is a key of
    ``BUILTIN_TRANSLATION`` is swapped for a new node calling the mapped
    out-of-place operator with the same ``args`` and ``kwargs``. All users
    of the old node are redirected to the new one and the old node is
    erased. The graph is then linted and the module recompiled.

    Args:
        gm: Graph module to rewrite. It is modified in place.

    Returns:
        None.
    """
    # Iterate over a snapshot: nodes are inserted into and erased from the
    # graph inside the loop, so the live node list must not drive iteration.
    for node in list(gm.graph.nodes):
        # Only function calls can target an `operator` builtin; other node
        # kinds carry string targets.
        if node.op != "call_function":
            continue

        out_of_place_op = BUILTIN_TRANSLATION.get(node.target)
        if out_of_place_op is None:
            continue

        # Insert the replacement directly before the in-place node so it
        # keeps the same position in the topological order.
        with gm.graph.inserting_before(node):
            out_of_place = gm.graph.call_function(
                out_of_place_op,
                args=node.args,
                kwargs=node.kwargs,
            )

        # Carry over metadata recorded on the original node so it is not
        # lost by the rewrite.
        out_of_place.meta = dict(node.meta)

        logger.debug("Replacing %s with %s", node.target, out_of_place.target)

        node.replace_all_uses_with(out_of_place)
        gm.graph.erase_node(node)

    gm.graph.lint()
    gm.recompile()

tests/py/dynamo/backend/test_specialized_models.py

+86-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch_tensorrt
33
from torch.testing._internal.common_utils import TestCase, run_tests
44

5-
from ..testing_utilities import lower_graph_testing
5+
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
66

77

88
class TestFakeTensors(TestCase):
@@ -57,6 +57,7 @@ def forward(self, x):
5757
self.assertAlmostEqual(
5858
max_diff,
5959
0,
60+
DECIMALS_OF_AGREEMENT,
6061
msg=f"MulInt TRT outputs don't match with the original model.",
6162
)
6263
torch._dynamo.reset()
@@ -113,6 +114,7 @@ def forward(self, x):
113114
self.assertAlmostEqual(
114115
max_diff,
115116
0,
117+
DECIMALS_OF_AGREEMENT,
116118
msg=f"AddFloat TRT outputs don't match with the original model.",
117119
)
118120

@@ -157,5 +159,88 @@ def forward(self, x):
157159
torch._dynamo.reset()
158160

159161

162+
class TestInputModifications(TestCase):
    """Tests for modules whose ``forward`` modifies its input in place.

    Each test symbolically traces a small module that uses in-place
    operators, compiles it through the ``torch_compile`` path of
    ``torch_tensorrt.compile``, and compares the compiled output with the
    output of the traced FX graph.
    """

    def _assert_trt_matches_torch(self, module, input_shape):
        """Compile ``module`` and compare its output with the traced graph.

        Args:
            module: ``torch.nn.Module`` instance to trace and compile.
            input_shape: Shape of the single random CUDA input tensor.
        """
        inputs = [torch.rand(*input_shape).cuda()]
        fx_graph = torch.fx.symbolic_trace(module)

        try:
            # Validate that the results between Torch and Torch-TRT are similar
            optimized_model = torch_tensorrt.compile(
                fx_graph,
                "torch_compile",
                inputs,
                min_block_size=1,
                pass_through_build_failures=True,
            )

            # The modules under test write into their argument, so each run
            # gets its own copy of the input. Otherwise the comparison would
            # depend on which model happened to run first.
            optimized_model_results = (
                optimized_model(*[tensor.clone() for tensor in inputs])
                .detach()
                .cpu()
            )
            torch_model_results = (
                fx_graph(*[tensor.clone() for tensor in inputs]).detach().cpu()
            )

            max_diff = float(
                torch.max(torch.abs(optimized_model_results - torch_model_results))
            )
            self.assertAlmostEqual(
                max_diff,
                0,
                DECIMALS_OF_AGREEMENT,
                msg=(
                    f"{type(module).__name__} TRT outputs don't match "
                    "with the original model."
                ),
            )
        finally:
            # Reset Dynamo even when compilation or the assertion fails.
            torch._dynamo.reset()

    def test_input_modifications_add(self):
        class InplaceAdd(torch.nn.Module):
            def forward(self, x):
                x += 3
                y = x + 1
                return y

        self._assert_trt_matches_torch(InplaceAdd(), (3, 5, 7))

    def test_input_modifications_mul(self):
        class InplaceMul(torch.nn.Module):
            def forward(self, x):
                x *= 5.0
                x *= 1.9
                y = x + 1
                y /= 1.3
                return y

        self._assert_trt_matches_torch(InplaceMul(), (1, 3, 5, 7))
243+
244+
160245
if __name__ == "__main__":
161246
run_tests()

tests/py/dynamo/testing_utilities.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from torch._dynamo.utils import detect_fake_mode
88
from torch_tensorrt.dynamo import partitioning
99
from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile, constant_fold
10-
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
10+
from torch_tensorrt.dynamo.lowering import (
11+
get_decompositions,
12+
replace_builtin_inplace_ops,
13+
)
1114
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
1215

1316
DECIMALS_OF_AGREEMENT = 4
@@ -39,6 +42,8 @@ def fx_dynamo_testing_backend(
3942
with unittest.mock.patch.object(
4043
fake_mode, "allow_non_fake_inputs", True
4144
), fake_mode:
45+
replace_builtin_inplace_ops(gm)
46+
4247
# Invoke AOTAutograd to translate operators to aten
4348
graph_module = aot_export_for_compile(
4449
gm,

0 commit comments

Comments
 (0)