Skip to content

chore/fix: Restructure Dynamo directory [7 / x] #1981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
from torch_tensorrt.dynamo.backend._defaults import (
from torch_tensorrt.dynamo._defaults import (
PRECISION,
DEBUG,
WORKSPACE_SIZE,
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
import torch._dynamo as td

from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.common import CompilationSettings
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
get_decompositions,
)
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/backend/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import torch
import io
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
from torch_tensorrt.dynamo.common import (
CompilationSettings,
InputTensorSpec,
TRTInterpreter,
)
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/lowering/_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import torch

from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY
from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import deepcopy
from torch_tensorrt.dynamo import compile
from utils import lower_graph_testing
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT


class TestTRTModuleNextCompilation(TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
from torch_tensorrt.dynamo import compile
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT


class TestLowering(TestCase):
Expand Down
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import logging
from dataclasses import replace, fields

from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.common import CompilationSettings, use_python_runtime_parser
from typing import Any, Union, Sequence, Dict
from torch_tensorrt import _Input, Device
from ..common_utils import use_python_runtime_parser


logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import logging
from typing import Optional

from ._settings import CompilationSettings
from .input_tensor_spec import InputTensorSpec
from .fx2trt import TRTInterpreter, TRTInterpreterResult


logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, Sequence

from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.dynamo.backend._defaults import (
from torch_tensorrt.dynamo._defaults import (
PRECISION,
DEBUG,
WORKSPACE_SIZE,
Expand Down
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
NO_IMPLICIT_BATCH_DIM_SUPPORT,
tensorrt_converter,
)
from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa
from .input_tensor_spec import InputTensorSpec # noqa
from .lower_setting import LowerSetting # noqa
from .lower import compile # usort: skip #noqa

Expand Down
33 changes: 24 additions & 9 deletions py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,32 @@
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
from torch.fx.passes.splitter_base import SplitResult

from .fx2trt import TRTInterpreter, TRTInterpreterResult
from torch_tensorrt.dynamo.common import (
TRTInterpreter,
TRTInterpreterResult,
use_python_runtime_parser,
)
from .lower_setting import LowerSetting
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
from .passes.pass_utils import PassFunc, validate_inference
from ..common_utils import use_python_runtime_parser
from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt._Device import Device
from torch_tensorrt.dynamo._defaults import (
PRECISION,
DEBUG,
WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
MAX_AUX_STREAMS,
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
USE_PYTHON_RUNTIME,
)

logger = logging.getLogger(__name__)

Expand All @@ -35,24 +49,25 @@ def compile(
disable_tf32=False,
sparse_weights=False,
enabled_precisions=set(),
min_block_size: int = 3,
workspace_size=0,
min_block_size: int = MIN_BLOCK_SIZE,
workspace_size=WORKSPACE_SIZE,
dla_sram_size=1048576,
dla_local_dram_size=1073741824,
dla_global_dram_size=536870912,
calibrator=None,
truncate_long_and_double=False,
require_full_compilation=False,
debug=False,
explicit_batch_dimension=False,
debug=DEBUG,
refit=False,
timing_cache_prefix="",
save_timing_cache=False,
cuda_graph_batch_size=-1,
is_aten=False,
use_python_runtime=None,
max_aux_streams=None,
version_compatible=False,
optimization_level=None,
use_python_runtime=USE_PYTHON_RUNTIME,
max_aux_streams=MAX_AUX_STREAMS,
version_compatible=VERSION_COMPATIBLE,
optimization_level=OPTIMIZATION_LEVEL,
num_avg_timing_iters=1,
torch_executed_ops=[],
torch_executed_modules=[],
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn
from torch.fx.passes.pass_manager import PassManager

from .input_tensor_spec import InputTensorSpec
from torch_tensorrt.dynamo.common import InputTensorSpec
from torch_tensorrt.fx.passes.lower_basic_pass import (
fuse_permute_linear,
fuse_permute_matmul,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt import _Input
from ..input_tensor_spec import InputTensorSpec
from torch_tensorrt.dynamo.common import InputTensorSpec

from ..lower_setting import LowerSetting
from torch_tensorrt.fx.observer import Observer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import torch_tensorrt
from torch.testing._internal.common_utils import run_tests, TestCase
from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, LowerSetting
from torch_tensorrt.dynamo.common import InputTensorSpec


class TestTRTModule(TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.fx.passes import shape_prop
from torch.fx.passes.infra.pass_base import PassResult
from torch.testing._internal.common_utils import TestCase
from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter
from torch_tensorrt.dynamo.common import InputTensorSpec, TRTInterpreter
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
compose_bmm,
compose_chunk,
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from transformers import BertModel

from torch_tensorrt.dynamo.common_utils.test_utils import (
from torch_tensorrt.dynamo.common.test_utils import (
COSINE_THRESHOLD,
cosine_similarity,
)
Expand Down