fix: Repair version checking system for Torch #2118

Merged (2 commits, Jul 18, 2023)

py/torch_tensorrt/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -94,7 +94,7 @@ def _find_lib(name, paths):

 from torch_tensorrt import fx

-if version.parse(torch.__version__) >= version.parse("2.1.dev"):
+if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from torch_tensorrt import dynamo
     from torch_tensorrt.dynamo import backend

py/torch_tensorrt/_util.py (8 changes: 8 additions & 0 deletions)

@@ -30,3 +30,11 @@ def get_build_info() -> str:

 def set_device(gpu_id):
     _C.set_device(gpu_id)
+
+
+def sanitized_torch_version() -> str:
+    return (
+        torch.__version__
+        if ".nv" not in torch.__version__
+        else torch.__version__.split(".nv")[0]
+    )

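For context, here is a minimal standalone sketch of what the new helper does. The version string below is a hypothetical NVIDIA-suffixed build value, not taken from this PR: everything from the first ".nv" onward is stripped so that packaging can parse and compare the remainder as an ordinary PEP 440 version.

    # Mirrors sanitized_torch_version() on an assumed NVIDIA-style version string.
    from packaging import version

    raw = "2.1.0a0+4136153.nv23.06"  # hypothetical example value
    sanitized = raw if ".nv" not in raw else raw.split(".nv")[0]

    print(sanitized)                                             # 2.1.0a0+4136153
    print(version.parse(sanitized) >= version.parse("2.1.dev"))  # True
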
py/torch_tensorrt/dynamo/__init__.py (4 changes: 2 additions & 2 deletions)

@@ -1,6 +1,6 @@
-import torch
 from packaging import version
+from torch_tensorrt._util import sanitized_torch_version

-if version.parse(torch.__version__) >= version.parse("2.1.dev"):
+if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from torch_tensorrt.dynamo import fx_ts_compat
     from .backend import compile

(test for fix_reshape_batch_dim; filename not shown)

@@ -11,6 +11,7 @@
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim
 from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
+from torch_tensorrt._util import sanitized_torch_version

 _LOGGER = logging.getLogger(__name__)

@@ -43,7 +44,9 @@ def forward(self, x, y):
     %reshape : [num_users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
     return reshape
 """
-        if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
+        if version.parse(sanitized_torch_version()) < version.parse(
+            "2.1.0.dev20230620"
+        ):
             expected_graph = expected_graph.replace("num_users", "#users")

         assert (

(test for remove_duplicate_output_args; filename not shown)

@@ -8,6 +8,8 @@
 import torch.nn as nn

 import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup
+from torch_tensorrt._util import sanitized_torch_version
+
 from torch.testing._internal.common_utils import run_tests, TestCase

 _LOGGER = logging.getLogger(__name__)
@@ -57,7 +59,9 @@ def is_leaf_module(self, m, qn):
     return add
 """.strip()

-        if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
+        if version.parse(sanitized_torch_version()) < version.parse(
+            "2.1.0.dev20230620"
+        ):
             ttop_graph_expected = ttop_graph_expected.replace("num_users", "#users")

         assert (
@@ -71,7 +75,9 @@ def is_leaf_module(self, m, qn):
     return (x,)
 """.strip()

-        if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
+        if version.parse(sanitized_torch_version()) < version.parse(
+            "2.1.0.dev20230620"
+        ):
             ttop_a_graph_expected = ttop_a_graph_expected.replace("num_users", "#users")

         assert (

py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py (3 changes: 2 additions & 1 deletion)

@@ -3,10 +3,11 @@
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
 from packaging import version
+from torch_tensorrt._util import sanitized_torch_version

 import torch

-if version.parse(torch.__version__) >= version.parse("2.dev"):
+if version.parse(sanitized_torch_version()) >= version.parse("2.dev"):
     import torch._dynamo as torchdynamo

 from torch.fx.passes.infra.pass_base import PassResult

py/torch_tensorrt/fx/utils.py (4 changes: 2 additions & 2 deletions)

@@ -12,7 +12,7 @@
     replace_op_with_indices,
     run_const_fold,
 )
-
+from torch_tensorrt._util import sanitized_torch_version
 from .types import Shape, TRTDataType


@@ -160,7 +160,7 @@ def nested_decorator(f: Callable):
         def function_wrapper(*args, **kwargs):
             # Parse minimum and current Torch versions
             min_version = version.parse(min_torch_version)
-            current_version = version.parse(torch.__version__)
+            current_version = version.parse(sanitized_torch_version())

             if current_version < min_version:
                 raise AssertionError(

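The second hunk in py/torch_tensorrt/fx/utils.py sits inside a version-gating decorator, and only the changed line is visible above. Below is a hedged, self-contained sketch of that pattern built around the new helper; the outer name require_min_torch is illustrative rather than the library's actual decorator name, and only the sanitized_torch_version() comparison reflects what this PR changes.

    # Illustrative version-gating decorator in the style of the hunk above.
    from typing import Any, Callable

    from packaging import version
    from torch_tensorrt._util import sanitized_torch_version


    def require_min_torch(min_torch_version: str) -> Callable:  # hypothetical name
        def nested_decorator(f: Callable) -> Callable:
            def function_wrapper(*args: Any, **kwargs: Any) -> Any:
                # Parse minimum and current (sanitized) Torch versions
                min_version = version.parse(min_torch_version)
                current_version = version.parse(sanitized_torch_version())

                if current_version < min_version:
                    raise AssertionError(
                        f"{f.__name__} requires torch >= {min_torch_version}, "
                        f"found {current_version}"
                    )
                return f(*args, **kwargs)

            return function_wrapper

        return nested_decorator


    @require_min_torch("2.dev")
    def needs_torch_2() -> str:
        return "running on a sufficiently new torch"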