Skip to content

Py38 compatibility #2189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple

Expand Down Expand Up @@ -32,11 +34,11 @@ class _ShapeMode(Enum):
shape: Optional[
Tuple[int, ...] | Dict[str, Tuple[int, ...]]
] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
dtype: _enums.dtype = ( # type: ignore[name-defined]
dtype: _enums.dtype = (
_enums.dtype.unknown
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
_explicit_set_dtype: bool = False
format: _enums.TensorFormat = ( # type: ignore[name-defined]
format: _enums.TensorFormat = (
_enums.TensorFormat.contiguous
) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)

Expand Down Expand Up @@ -208,7 +210,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
return False

@staticmethod
def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
def _parse_dtype(dtype: Any) -> _enums.dtype:
if isinstance(dtype, torch.dtype):
if dtype == torch.long:
return _enums.dtype.long
Expand Down Expand Up @@ -236,7 +238,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
)

@staticmethod
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: # type: ignore[name-defined]
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
if dtype == _enums.dtype.long:
return torch.long
elif dtype == _enums.dtype.int32:
Expand All @@ -255,7 +257,7 @@ def is_trt_dtype(self) -> bool:
return bool(self.dtype != _enums.dtype.long)

@staticmethod
def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defined]
def _parse_format(format: Any) -> _enums.TensorFormat:
if isinstance(format, torch.memory_format):
if format == torch.contiguous_format:
return _enums.TensorFormat.contiguous
Expand Down
5 changes: 4 additions & 1 deletion py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Callable, List, Optional, Sequence, Set, TypeGuard
from typing import Any, Callable, List, Optional, Sequence, Set

import torch
import torch.fx
Expand All @@ -12,6 +14,7 @@
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.ts._compiler import compile as torchscript_compile
from typing_extensions import TypeGuard


def _non_fx_input_interface(
Expand Down
6 changes: 4 additions & 2 deletions py/torch_tensorrt/dynamo/aten_tracer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import copy
import sys
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import torch
import torch._dynamo as torchdynamo
Expand All @@ -22,7 +24,7 @@
)
from typing_extensions import TypeAlias

Value: TypeAlias = Tuple["Value", ...] | List["Value"] | Dict[str, "Value"]
Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]


class DynamoConfig:
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
from functools import partial
from typing import Any, Callable, Sequence
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/compile.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import collections.abc
import logging
from typing import Any, List, Optional, Set, Tuple
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/conversion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import io
from typing import Sequence

Expand Down
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from enum import Enum, auto
Expand Down Expand Up @@ -28,7 +30,7 @@
Dict[str, Argument],
str,
],
TRTTensor | Sequence[TRTTensor],
Union[TRTTensor, Sequence[TRTTensor]],
]


Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shape.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import List, Optional, Tuple

import numpy as np
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Optional, Sequence, Set

import torch
Expand Down
5 changes: 4 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Type, TypeAlias
from typing import Any, Callable, Dict, Optional, Type

import torch
from torch._ops import OpOverload
from torch.fx import GraphModule, Node
from typing_extensions import TypeAlias

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
from typing import Any, List, Optional, Tuple

Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
from dataclasses import fields, replace
from typing import Any, Callable, Dict, Optional, Sequence
Expand Down
12 changes: 7 additions & 5 deletions py/torch_tensorrt/ts/_compile_spec.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any, Dict, List, Optional, Set

Expand Down Expand Up @@ -39,7 +41,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
)


def _parse_op_precision(precision: Any) -> _enums.dtype: # type: ignore[name-defined]
def _parse_op_precision(precision: Any) -> _enums.dtype:
if isinstance(precision, torch.dtype):
if precision == torch.int8:
return _enums.dtype.int8
Expand All @@ -63,7 +65,7 @@ def _parse_op_precision(precision: Any) -> _enums.dtype: # type: ignore[name-de
)


def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: # type: ignore[name-defined]
def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]:
parsed_precisions = set()
if any(isinstance(precisions, type) for type in [list, tuple, set]):
for p in precisions:
Expand All @@ -73,7 +75,7 @@ def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: # type: ig
return parsed_precisions


def _parse_device_type(device: Any) -> _enums.DeviceType: # type: ignore[name-defined]
def _parse_device_type(device: Any) -> _enums.DeviceType:
if isinstance(device, torch.device):
if device.type == "cuda":
return _C.DeviceType.gpu
Expand Down Expand Up @@ -346,10 +348,10 @@ def TensorRTCompileSpec(
device: torch.device | Device = Device._current_device(),
disable_tf32: bool = False,
sparse_weights: bool = False,
enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, # type: ignore[name-defined]
enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
refit: bool = False,
debug: bool = False,
capability: _enums.EngineCapability = _enums.EngineCapability.default, # type: ignore[name-defined]
capability: _enums.EngineCapability = _enums.EngineCapability.default,
num_avg_timing_iters: int = 1,
workspace_size: int = 0,
dla_sram_size: int = 1048576,
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/ts/_compiler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any, List, Optional, Sequence, Set, Tuple

import torch
Expand Down