feat: Implement Input class support for FX backend. #1763

Status: Closed · wants to merge 4 commits
core/runtime/TRTEngine.cpp — 1 addition, 1 deletion
@@ -148,10 +148,10 @@ TRTEngine::TRTEngine(
 }
 
 TRTEngine::~TRTEngine() {
-  rt.reset();
   trt_engine_profiler.reset();
   exec_ctx.reset();
   cuda_engine.reset();
+  rt.reset();
 }
 
 void TRTEngine::disable_profiling() {
py/setup.py — 5 additions, 1 deletion
@@ -350,6 +350,7 @@ def run(self):
 if FX_ONLY:
     ext_modules = None
     packages = [
+        "torch_tensorrt",
         "torch_tensorrt.fx",
         "torch_tensorrt.fx.converters",
         "torch_tensorrt.fx.passes",
@@ -358,6 +359,7 @@ def run(self):
"torch_tensorrt.fx.tracer.dispatch_tracer",
]
package_dir = {
"torch_tensorrt": "torch_tensorrt/",
"torch_tensorrt.fx": "torch_tensorrt/fx",
"torch_tensorrt.fx.converters": "torch_tensorrt/fx/converters",
"torch_tensorrt.fx.passes": "torch_tensorrt/fx/passes",
@@ -437,7 +439,9 @@ def run(self):
"bin/*",
"BUILD",
"WORKSPACE",
],
]
if not FX_ONLY
else ["_Input.py"]
},
exclude_package_data={
"": ["*.cpp"],
Expand Down
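Taken together with the removal of the `_C` import from `_Input.py` below, these packaging changes appear intended to make the top-level `torch_tensorrt` package (and its `Input` class) importable in an FX-only build that ships no C++ runtime. A minimal sketch of what that enables, assuming an FX-only install:

```python
# Assumed behavior after this change: importing Input no longer pulls in the
# compiled _C extension, so it works in an FX-only (pure-Python) install.
from torch_tensorrt import Input

spec = Input(shape=(1, 3, 224, 224))  # static-shape input specification
```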
py/torch_tensorrt/_Input.py — 9 additions, 56 deletions
@@ -4,7 +4,6 @@
 import torch
 
 from torch_tensorrt import _enums
-from torch_tensorrt import _C
 
 
 class Input(object):
@@ -41,6 +40,7 @@ class _ShapeMode(Enum):
     DOMAIN_OFFSET = 2.0
     low_tensor_domain_incl = 0.0
     high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
+    torch_dtype = None
Review comment (Collaborator): Should we derive torch_dtype from self.dtype?
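A minimal sketch of what the reviewer may be suggesting — deriving the torch dtype from the already-parsed self.dtype enum instead of caching it separately. The enum members and mapping are assumptions, not code from this PR:

```python
import torch
from torch_tensorrt import _enums

# Assumed mapping from the internal dtype enum back to torch dtypes.
_ENUM_TO_TORCH = {
    _enums.dtype.float: torch.float32,
    _enums.dtype.half: torch.float16,
    _enums.dtype.int8: torch.int8,
    _enums.dtype.int32: torch.int32,
    _enums.dtype.bool: torch.bool,
}

def derive_torch_dtype(input_obj) -> torch.dtype:
    # Fall back to float32 when the dtype was never explicitly set.
    return _ENUM_TO_TORCH.get(input_obj.dtype, torch.float32)
```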


     def __init__(self, *args, **kwargs):
         """__init__ Method for torch_tensorrt.Input
 
@@ -138,6 +138,9 @@ def __init__(self, *args, **kwargs):
             )
 
         if "dtype" in kwargs:
+            if isinstance(kwargs["dtype"], torch.dtype):
+                self.torch_dtype = kwargs["dtype"]
+
             self.dtype = Input._parse_dtype(kwargs["dtype"])
             self._explicit_set_dtype = True
 
@@ -173,59 +176,6 @@ def __str__(self) -> str:
         else:
             raise RuntimeError("Unknown input shape mode")
 
-    def _to_internal(self) -> _C.Input:
Review comment (Collaborator): Why was this taken out?
-        internal_in = _C.Input()
-        if self.shape_mode == Input._ShapeMode.DYNAMIC:
-            if not Input._supported_input_size_type(self.shape["min_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["min_shape"]))
-                    + " for min_shape"
-                )
-            else:
-                internal_in.min = self.shape["min_shape"]
-
-            if not Input._supported_input_size_type(self.shape["opt_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["opt_shape"]))
-                    + " for opt_shape"
-                )
-            else:
-                internal_in.opt = self.shape["opt_shape"]
-
-            if not Input._supported_input_size_type(self.shape["max_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["max_shape"]))
-                    + " for max_shape"
-                )
-            else:
-                internal_in.max = self.shape["max_shape"]
-            internal_in.input_is_dynamic = True
-        else:
-            if not Input._supported_input_size_type(self.shape):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape))
-                    + " for shape"
-                )
-            else:
-                internal_in.opt = self.shape
-            internal_in.input_is_dynamic = False
-
-        if self.dtype != _enums.dtype.unknown:
-            self._explicit_set_dtype = True
-        else:
-            self._explicit_set_dtype = False
-
-        internal_in.dtype = Input._parse_dtype(self.dtype)
-        internal_in._explicit_set_dtype = self._explicit_set_dtype
-        internal_in.format = Input._parse_format(self.format)
-
-        internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain)
-        return internal_in

     @staticmethod
     def _supported_input_size_type(input_size: Any) -> bool:
         if isinstance(input_size, torch.Size):
 
@@ -304,6 +254,7 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
                 Input.low_tensor_domain_incl,
                 Input.high_tensor_domain_excl,
             )
+
         elif len(domain) == 2:
             domain_lo, domain_hi = domain
 
@@ -416,8 +367,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
             )
 
         if self.shape_mode == Input._ShapeMode.STATIC:
-            return torch.randn(self.shape).to(dtype=self.dtype)
+            return torch.randn(self.shape).to(
+                dtype=self.dtype if not self.torch_dtype else self.torch_dtype
+            )
         else:
             return torch.randn(self.shape[optimization_profile_field]).to(
-                dtype=self.dtype
+                dtype=self.dtype if not self.torch_dtype else self.torch_dtype
             )
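A usage sketch of the behavior this diff enables: when an Input is constructed with a torch.dtype, example_tensor() now casts with the original torch dtype instead of the parsed enum value. The shapes below are illustrative:

```python
import torch
from torch_tensorrt import Input

inp = Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(16, 3, 224, 224),
    dtype=torch.half,  # stored on inp.torch_dtype by the new __init__ branch
)
t = inp.example_tensor(optimization_profile_field="opt_shape")
assert t.dtype == torch.half and t.shape == (8, 3, 224, 224)
```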
py/torch_tensorrt/_compile.py — 1 addition, 2 deletions
@@ -142,9 +142,8 @@ def compile(
         return torch_tensorrt.fx.compile(
             module,
             inputs,
-            lower_precision=lower_precision,
-            max_batch_size=inputs[0].size(0),
             explicit_batch_dimension=True,
+            lower_precision=lower_precision,
             dynamic_batch=False,
             **kwargs,
         )
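For context, a hedged sketch of the resulting dispatch when compiling through the FX path; the model and example input are illustrative, not from the PR:

```python
import torch
import torch_tensorrt
from torch_tensorrt.fx.utils import LowerPrecision

model = torch.nn.Linear(8, 4).eval().cuda()
trt_mod = torch_tensorrt.fx.compile(
    model,
    [torch.randn(2, 8, device="cuda")],
    explicit_batch_dimension=True,
    lower_precision=LowerPrecision.FP32,
    dynamic_batch=False,
)
```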
py/torch_tensorrt/fx/input_tensor_spec.py — 38 additions
@@ -4,6 +4,7 @@

 from .types import Shape, ShapeRange
 from .utils import get_dynamic_dims
+from torch_tensorrt._Input import Input
 
 
 def generate_input_specs(inputs, lower_setting, additional_inputs=None):
Review comment (Contributor): Same here. These APIs are used in many internal products.

@@ -116,6 +117,43 @@ def from_tensors(cls, tensors: Sequence[torch.Tensor]) -> List["InputTensorSpec"]:
         assert isinstance(tensors, (list, tuple))
         return [cls.from_tensor(t) for t in tensors]
 
+    @classmethod
Review comment (Contributor, frank-wei, Apr 12, 2023): I assume input_obj is used as the general interface here.

+    def from_input(cls, input_obj: Input) -> "InputTensorSpec":
+        """
+        Produce an InputTensorSpec named tuple from a torch_tensorrt.Input object.
+
+        Args:
+            input_obj (torch_tensorrt.Input): An input description carrying static
+                or dynamic shape information and an optional torch dtype.
+
+        Returns:
+            An InputTensorSpec named tuple.
+        """
+        assert isinstance(input_obj, Input)
+        input_spec = None
+        if isinstance(input_obj.shape, dict):
+            min_shape = input_obj.shape["min_shape"]
+            opt_shape = input_obj.shape["opt_shape"]
+            max_shape = input_obj.shape["max_shape"]
+            dyn_shape = []
+            for min, opt, max in zip(min_shape, opt_shape, max_shape):
+                if min == opt == max:
+                    dyn_shape.append(min)
+                else:
+                    dyn_shape.append(-1)
+            dtype = input_obj.torch_dtype
+            input_spec = cls(
+                shape=dyn_shape,
+                dtype=dtype,
+                shape_ranges=[(min_shape, opt_shape, max_shape)],
+            )
+        else:
+            shape = input_obj.shape
+            dtype = input_obj.torch_dtype
+            input_spec = cls(shape=shape, dtype=dtype)
+
+        return input_spec

     @classmethod
     def from_tensors_with_dynamic_batch_size(
         cls,
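An illustrative example of how the new from_input classmethod maps a dynamic-shape Input onto an InputTensorSpec, with -1 wildcards for dimensions that vary across the optimization profile; the shapes are made up:

```python
import torch
from torch_tensorrt import Input
from torch_tensorrt.fx import InputTensorSpec

inp = Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(16, 3, 224, 224),
    dtype=torch.float,
)
spec = InputTensorSpec.from_input(inp)
# Only the batch dimension differs across min/opt/max, so it becomes -1:
# spec.shape        -> [-1, 3, 224, 224]
# spec.shape_ranges -> [((1, 3, 224, 224), (8, 3, 224, 224), (16, 3, 224, 224))]
```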
py/torch_tensorrt/fx/lower.py — 3 additions, 1 deletion
@@ -7,6 +7,7 @@
 import torch
 import torch.fx as fx
 import torch.nn as nn
+import torch_tensorrt
 import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
 from torch.fx.passes.splitter_base import SplitResult
 
@@ -29,8 +30,8 @@
 def compile(
     module: nn.Module,
     input,
-    min_acc_module_size: int = 10,
     max_batch_size: int = 2048,
+    min_acc_module_size: int = 10,
     max_workspace_size=1 << 25,
     explicit_batch_dimension=False,
     lower_precision=LowerPrecision.FP16,
@@ -302,6 +303,7 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
                 conversion_fn = fp16_conversion_fn
 
             inputs = tuple(conversion_fn(x) for x in inputs)
+
             if lower_setting.is_aten:
                 pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
                     inputs, additional_inputs
py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py — 29 additions, 11 deletions
@@ -9,13 +9,14 @@
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
 from torch_tensorrt.fx.utils import LowerPrecision
-
-from ..input_tensor_spec import generate_input_specs
+from torch_tensorrt import _Input
+from ..input_tensor_spec import generate_input_specs, InputTensorSpec
 
 from ..lower_setting import LowerSetting
 from ..observer import Observer
 from ..passes.remove_duplicate_output_args import remove_duplicate_output_args
 from .graph_opts import common_subexpression_elimination
+from .pass_utils import extract_example_tensors_from_input

 from .lower_basic_pass import (  # noqa
     fix_clamp_numerical_limits_to_fp16,
@@ -165,6 +166,7 @@ def _split_pass(self) -> PassManager:
                 )
             )
         )
+
         return PassManager.build_from_passlist(passes)
 
     def _trt_lower_pass(self) -> PassManager:
@@ -192,13 +194,17 @@ def lower_func(split_result: SplitResult) -> nn.Module:
_LOGGER.info(f"Now lowering submodule {submod_name}")
lowering_start_time = datetime.datetime.now()

self.lower_setting.input_specs = generate_input_specs(
submod_inputs,
self.lower_setting,
additional_submodule_inputs[submod_name]
if additional_submodule_inputs
else None,
)
if self._trt_input:
self.lower_setting.input_specs = self._trt_input
else:
self.lower_setting.input_specs = generate_input_specs(
submod_inputs,
self.lower_setting,
additional_submodule_inputs[submod_name]
if additional_submodule_inputs
else None,
)

lowered_module = self._lower_func(
submod, submod_inputs, self.lower_setting, submod_name
)
@@ -262,7 +268,13 @@ def _default_replace_mutable_op_pass(self) -> PassManager:
     def build_trt_lower_pipeline(
         self, input: Input, additional_input: Optional[Input] = None
     ) -> PassManager:
-        self._input = input
Review comment (Contributor): Can we add a new function instead of changing the behavior of build_trt_lower_pipeline?

+
+        self._input = extract_example_tensors_from_input(input)
+        self._trt_input = []
+        for input_obj in input:
+            if isinstance(input_obj, _Input.Input):
+                self._trt_input.append(InputTensorSpec.from_input(input_obj))
+
         self._additional_input = additional_input
         passes = []
 
@@ -278,7 +290,13 @@ def build_aten2trt_lower_pipeline(
     def build_aten2trt_lower_pipeline(
         self, input: Input, additional_input: Optional[Input] = None
     ) -> PassManager:
-        self._input = input
+
+        self._input = extract_example_tensors_from_input(input)
+        self._trt_input = []
+        for input_obj in input:
+            if isinstance(input_obj, _Input.Input):
+                self._trt_input.append(InputTensorSpec.from_input(input_obj))
+
         self._additional_input = additional_input
         passes = []
         passes.append(
py/torch_tensorrt/fx/passes/pass_utils.py — 28 additions, 3 deletions
@@ -8,6 +8,7 @@
 import torch
 from torch import fx
 from torch.fx.passes.shape_prop import ShapeProp
+from torch_tensorrt import _Input
 
 # Create an alias for module input type to avoid littering pyre-ignore for Any
 # throughout the file.
@@ -21,6 +22,30 @@
 FINAL_CHECK_RTOL_MULTIPLIER: float = 10
 
 
+def extract_example_tensors_from_input(
Review comment (Contributor): Can we pass real tensors as inputs to the lowering workflow? Why do we need the input_obj as input?

+    inputs: Any, device: torch.device = torch.device("cuda")
+):
+    input_tensors = []
+    for input_obj in inputs:
+        if isinstance(input_obj, _Input.Input):
+            if isinstance(input_obj.shape, dict):
+                input_tensors.append(
+                    input_obj.example_tensor(optimization_profile_field="opt_shape").to(
+                        device
+                    )
+                )
+            else:
+                input_tensors.append(input_obj.example_tensor().to(device))
+        elif isinstance(input_obj, torch.Tensor):
+            input_tensors.append(input_obj)
+        else:
+            raise ValueError(
+                "Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor"
+            )
+
+    return input_tensors
+
+
 class RelaxAccuracyCheckMode:
     """
     Basically a context manager that controls a global variable that controls
 
@@ -114,10 +139,10 @@ def pass_with_validation(
     *args,
     **kwargs,
 ) -> fx.GraphModule:
-    res0 = module(*input)
+    input_tensors = extract_example_tensors_from_input(input)
+    res0 = module(*input_tensors)
     processed_module = pass_(module, input, *args, **kwargs)
-    res1 = processed_module(*input)
-
+    res1 = processed_module(*input_tensors)
     tensor_res_0 = _collect_tensors(res0)
     tensor_res_1 = _collect_tensors(res1)
     relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
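A usage sketch for the extract_example_tensors_from_input helper added above: mixed torch.Tensor and torch_tensorrt.Input entries are normalized into concrete tensors, which is what lets pass_with_validation call the module directly. The shapes are illustrative:

```python
import torch
from torch_tensorrt import Input
from torch_tensorrt.fx.passes.pass_utils import extract_example_tensors_from_input

inputs = [
    torch.randn(2, 3, device="cuda"),        # passed through unchanged
    Input(shape=(2, 3), dtype=torch.float),  # materialized via example_tensor()
]
tensors = extract_example_tensors_from_input(inputs)
assert all(isinstance(t, torch.Tensor) for t in tensors)
```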