fix: Add support for truncate_long_and_double in Dynamo #1969

Closed · wants to merge 3 commits
8 changes: 6 additions & 2 deletions py/torch_tensorrt/dynamo/backend/__init__.py
@@ -17,6 +17,7 @@
     MAX_WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    TRUNCATE_LONG_AND_DOUBLE,
 )


@@ -40,7 +41,7 @@ def compile(
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
     calibrator=None,
-    truncate_long_and_double=False,
+    truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation=False,
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
@@ -54,7 +55,7 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
         + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures}"
+        + "truncate_long_and_double, torch_executed_ops, pass_through_build_failures}"
     )

     if not isinstance(inputs, collections.abc.Sequence):
@@ -86,6 +87,7 @@ def compile(
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
+        truncate_long_and_double=truncate_long_and_double,
         **kwargs,
     )

@@ -109,6 +111,7 @@ def create_backend(
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
@@ -131,6 +134,7 @@ def create_backend(
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
+        truncate_long_and_double=truncate_long_and_double,
     )

     return partial(
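For context, a minimal usage sketch of the new flag through the Dynamo `compile` entrypoint. The module and inputs below are illustrative only (modeled on this PR's tests), not part of the diff:

import torch
from torch_tensorrt.dynamo import compile

# Hypothetical module whose inputs arrive as float64 tensors
class DoubleAdd(torch.nn.Module):
    def forward(self, x, y):
        return torch.ops.aten.add.Tensor(x, y)

model = torch.fx.symbolic_trace(DoubleAdd())
inputs = [
    torch.randn(4, 4, dtype=torch.float64).cuda(),
    torch.randn(4, 4, dtype=torch.float64).cuda(),
]

# With the new flag, 64-bit inputs are cast to 32-bit before entering
# TensorRT-accelerated subgraphs
trt_model = compile(model, inputs, truncate_long_and_double=True, min_block_size=1)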
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -6,3 +6,4 @@
 MAX_WORKSPACE_SIZE = 20 << 30
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
+TRUNCATE_LONG_AND_DOUBLE = False
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/_settings.py
@@ -8,6 +8,7 @@
     MAX_WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    TRUNCATE_LONG_AND_DOUBLE,
 )


@@ -19,3 +20,4 @@ class CompilationSettings:
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
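As a quick sanity sketch (assuming only the fields shown in this diff), `create_backend` packs the user's arguments into this dataclass, so downstream passes read the flag from a single settings object:

from torch_tensorrt.dynamo.backend._settings import CompilationSettings

settings = CompilationSettings(truncate_long_and_double=True)
assert settings.truncate_long_and_double  # consumed later by _compile_module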
43 changes: 43 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -12,6 +12,8 @@
     partition,
     get_submod_inputs,
 )
+
+from torch_tensorrt.dynamo.backend.utils import repair_long_or_double_input
 from torch_tensorrt.dynamo.backend.conversion import convert_module

 from torch._dynamo.backends.common import fake_tensor_unsupported
@@ -130,6 +132,47 @@ def _compile_module(
             partitioned_module, submodule, sample_inputs
         )

+        # Ensure all submodule inputs do not require a gradient
+        for param in submodule_inputs:
+            param.requires_grad = False
+
+        # Handle long/double inputs if requested by the user
+        if settings.truncate_long_and_double:
+            num_submodule_inputs = len(submodule_inputs)
+            repaired_outputs_once = False
+
+            # For each input to the TRT subgraph, check if its type is long/double
+            for position in range(num_submodule_inputs):
+                param = submodule_inputs[position]
+
+                # If the data type of the input is long/double, insert necessary
+                # casts to replace the operation
+                if param.dtype in (torch.int64, torch.float64):
+                    # Ensure outputs are only repaired once per submodule to avoid
+                    # unnecessary ops showing up in the graph
+                    if not repaired_outputs_once:
+                        submodule_outputs = submodule(*submodule_inputs)
+
+                    repair_long_or_double_input(
+                        partitioned_module,
+                        position,
+                        name,
+                        None if repaired_outputs_once else submodule_outputs,
+                        param.dtype,
+                    )
+
+                    repaired_outputs_once = True
+
+                    # Repair submodule inputs in accordance with inserted casts
+                    dtype_32bit = (
+                        torch.int32 if (param.dtype == torch.int64) else torch.float32
+                    )
+                    submodule_inputs = (
+                        submodule_inputs[:position]
+                        + (param.to(dtype_32bit),)
+                        + submodule_inputs[position + 1 :]
+                    )
+
         # Create TRT Module from submodule
         trt_mod = convert_module(
             submodule,
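`repair_long_or_double_input` is imported from `py/torch_tensorrt/dynamo/backend/utils.py`, which is not part of this diff. As a rough mental model only (a sketch, not the PR's actual fx graph-surgery implementation), the repair amounts to downcasting the offending input to 32 bits at the submodule boundary and upcasting outputs back to their original 64-bit dtype:

import torch

def cast_boundary_sketch(submodule, position, orig_dtype):
    # Hypothetical stand-in for repair_long_or_double_input: wrap the
    # submodule so input `position` is downcast on entry and outputs are
    # upcast on exit. The real helper rewrites the partitioned fx graph
    # with cast nodes instead of wrapping the callable.
    dtype_32bit = torch.int32 if orig_dtype == torch.int64 else torch.float32

    def wrapped(*args):
        args = list(args)
        args[position] = args[position].to(dtype_32bit)  # downcast 64-bit input
        outputs = submodule(*args)
        if isinstance(outputs, torch.Tensor):
            return outputs.to(orig_dtype)  # restore original output dtype
        return tuple(out.to(orig_dtype) for out in outputs)

    return wrapped

This mirrors the loop above: inputs are repaired per position, while outputs only need repairing once per submodule, hence the `repaired_outputs_once` guard.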
122 changes: 122 additions & 0 deletions py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
@@ -0,0 +1,122 @@
from torch_tensorrt.dynamo.backend.lowering import partition
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
from copy import deepcopy
from torch_tensorrt.dynamo import compile
from utils import lower_graph_testing
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT


class Test64BitInput(TestCase):
    def test_float64_input_full_support(self):
        class FullySupportedMultiOp(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.mean.dim(
                    torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), 2), [0]
                )

        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)

        self.assertEquals(
            len(list(partitioned_graph.named_children())),
            1,
            "All operators are supported, there should be one segment",
        )

        inputs = [
            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
        ]

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = compile(
            fx_graph,
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
            truncate_long_and_double=True,
            debug=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "TRT outputs don't match with the original model.",
        )

    def test_int64_input_partial_support(self):
        class PartiallySupportedMultiOp(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.div.Tensor_mode(
                    x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
                )

        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
        unexpected_ops = {torch.ops.aten.add.Tensor}

        inputs = [
            torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
            torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
        ]

        (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
            fx_graph,
            inputs,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
            torch_executed_ops={"torch.ops.aten.add.Tensor"},
            testing_partitioning=True,
        )

        self.assertEquals(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )
        self.assertEquals(
            len(partitioned_graphs),
            1,
            "Without control flow breaks, there should only be a single graph",
        )
        self.assertEquals(
            len(list(partitioned_graphs[0].named_children())),
            1,
            "Certain operators are set to run in Torch, expected 1 segment",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = compile(
            fx_graph,
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
            truncate_long_and_double=True,
            debug=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "TRT outputs don't match with the original model.",
        )


if __name__ == "__main__":
    run_tests()