Skip to content

Commit 0fa3010

Browse files
committed
fix: Remove input aliasing with builtin ops
- Add replacements for inplace builtin operators with their out-of-place equivalents - Add utility to automatically perform replacement prior to AOT tracing - Add test cases to verify inplace operators are replaced accurately
1 parent c728b60 commit 0fa3010

File tree

4 files changed

+137
-1
lines changed

4 files changed

+137
-1
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from torch._ops import OpOverload
1414
from torch_tensorrt.dynamo import CompilationSettings
1515
from torch_tensorrt.dynamo.compile import compile_module
16-
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
16+
from torch_tensorrt.dynamo.lowering import (
17+
get_decompositions,
18+
replace_builtin_inplace_ops,
19+
)
1720
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
1821
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
1922

@@ -74,6 +77,7 @@ def _pretraced_backend(
7477
with unittest.mock.patch.object(
7578
fake_mode, "allow_non_fake_inputs", True
7679
), fake_mode:
80+
replace_builtin_inplace_ops(gm)
7781
# Invoke AOTAutograd to translate operators to aten
7882
graph_module = aot_export_for_compile(
7983
gm,

py/torch_tensorrt/dynamo/lowering/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
from ._fusers import * # noqa: F401
33
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
44
from ._pre_aot_lowering import register_substitution # noqa: F401
5+
from ._replace_inplace_ops import replace_builtin_inplace_ops
56
from .substitutions import * # noqa: F401
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import logging
2+
import operator
3+
4+
import torch
5+
6+
logger = logging.getLogger(__name__)
7+
8+
BUILTIN_TRANSLATION = {
9+
operator.ipow: operator.pow,
10+
operator.imul: operator.mul,
11+
operator.imatmul: operator.matmul,
12+
operator.ifloordiv: operator.floordiv,
13+
operator.itruediv: operator.truediv,
14+
operator.imod: operator.mod,
15+
operator.iadd: operator.add,
16+
operator.isub: operator.sub,
17+
operator.ilshift: operator.lshift,
18+
operator.irshift: operator.rshift,
19+
operator.iand: operator.and_,
20+
operator.ixor: operator.xor,
21+
operator.ior: operator.or_,
22+
}
23+
24+
25+
def replace_builtin_inplace_ops(gm: torch.fx.GraphModule) -> None:
26+
"""Replaces inplace builtins from Python's operator class
27+
28+
Replaces inplace builtins with out-of-place equivalent ops
29+
"""
30+
for node in gm.graph.nodes:
31+
# If a node uses one of the inplace builtins
32+
# Replace it with its out-of-place equivalent
33+
if node.target in BUILTIN_TRANSLATION:
34+
out_of_place_op = BUILTIN_TRANSLATION[node.target]
35+
36+
# Replace inplace operator node and delete
37+
with gm.graph.inserting_before(node):
38+
out_of_place = gm.graph.call_function(
39+
out_of_place_op,
40+
args=node.args,
41+
kwargs=node.kwargs,
42+
)
43+
44+
logger.debug(f"Replacing {node.target} with {out_of_place.target}")
45+
46+
node.replace_all_uses_with(out_of_place)
47+
gm.graph.erase_node(node)
48+
49+
gm.graph.lint()
50+
gm.recompile()

tests/py/dynamo/backend/test_specialized_models.py

+81
Original file line numberDiff line numberDiff line change
@@ -157,5 +157,86 @@ def forward(self, x):
157157
torch._dynamo.reset()
158158

159159

160+
class TestInputModifications(TestCase):
161+
def test_input_modifications_add(self):
162+
class InplaceAdd(torch.nn.Module):
163+
def forward(self, x):
164+
x += 3
165+
y = x + 1
166+
return y
167+
168+
inputs = [
169+
torch.rand(
170+
3,
171+
5,
172+
7,
173+
).cuda(),
174+
]
175+
176+
fx_graph = torch.fx.symbolic_trace(InplaceAdd())
177+
178+
# Validate that the results between Torch and Torch-TRT are similar
179+
optimized_model = torch_tensorrt.compile(
180+
fx_graph,
181+
"torch_compile",
182+
inputs,
183+
min_block_size=1,
184+
pass_through_build_failures=True,
185+
)
186+
optimized_model_results = optimized_model(*inputs).detach().cpu()
187+
torch_model_results = fx_graph(*inputs).detach().cpu()
188+
189+
max_diff = float(
190+
torch.max(torch.abs(optimized_model_results - torch_model_results))
191+
)
192+
self.assertAlmostEqual(
193+
max_diff,
194+
0,
195+
msg=f"InplaceAdd TRT outputs don't match with the original model.",
196+
)
197+
torch._dynamo.reset()
198+
199+
def test_input_modifications_mul(self):
200+
class InplaceMul(torch.nn.Module):
201+
def forward(self, x):
202+
x *= 5.0
203+
x *= 1.9
204+
y = x + 1
205+
y /= 1.3
206+
return y
207+
208+
inputs = [
209+
torch.rand(
210+
1,
211+
3,
212+
5,
213+
7,
214+
).cuda(),
215+
]
216+
217+
fx_graph = torch.fx.symbolic_trace(InplaceMul())
218+
219+
# Validate that the results between Torch and Torch-TRT are similar
220+
optimized_model = torch_tensorrt.compile(
221+
fx_graph,
222+
"torch_compile",
223+
inputs,
224+
min_block_size=1,
225+
pass_through_build_failures=True,
226+
)
227+
optimized_model_results = optimized_model(*inputs).detach().cpu()
228+
torch_model_results = fx_graph(*inputs).detach().cpu()
229+
230+
max_diff = float(
231+
torch.max(torch.abs(optimized_model_results - torch_model_results))
232+
)
233+
self.assertAlmostEqual(
234+
max_diff,
235+
0,
236+
msg=f"InplaceMul TRT outputs don't match with the original model.",
237+
)
238+
torch._dynamo.reset()
239+
240+
160241
if __name__ == "__main__":
161242
run_tests()

0 commit comments

Comments
 (0)