feat: Implement Input class support for FX backend. #1763
@@ -4,7 +4,6 @@
 import torch

 from torch_tensorrt import _enums
-from torch_tensorrt import _C


 class Input(object):
@@ -41,6 +40,7 @@ class _ShapeMode(Enum):
     DOMAIN_OFFSET = 2.0
     low_tensor_domain_incl = 0.0
     high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
+    torch_dtype = None

     def __init__(self, *args, **kwargs):
         """__init__ Method for torch_tensorrt.Input
@@ -138,6 +138,9 @@ def __init__(self, *args, **kwargs):
                 )

         if "dtype" in kwargs:
+            if isinstance(kwargs["dtype"], torch.dtype):
+                self.torch_dtype = kwargs["dtype"]
+
             self.dtype = Input._parse_dtype(kwargs["dtype"])
             self._explicit_set_dtype = True
@@ -173,59 +176,6 @@ def __str__(self) -> str:
         else:
             raise RuntimeError("Unknown input shape mode")

-    def _to_internal(self) -> _C.Input:
-        internal_in = _C.Input()
-        if self.shape_mode == Input._ShapeMode.DYNAMIC:
-            if not Input._supported_input_size_type(self.shape["min_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["min_shape"]))
-                    + " for min_shape"
-                )
-            else:
-                internal_in.min = self.shape["min_shape"]
-
-            if not Input._supported_input_size_type(self.shape["opt_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["opt_shape"]))
-                    + " for opt_shape"
-                )
-            else:
-                internal_in.opt = self.shape["opt_shape"]
-
-            if not Input._supported_input_size_type(self.shape["max_shape"]):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape["max_shape"]))
-                    + " for max_shape"
-                )
-            else:
-                internal_in.max = self.shape["max_shape"]
-            internal_in.input_is_dynamic = True
-        else:
-            if not Input._supported_input_size_type(self.shape):
-                raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(self.shape))
-                    + " for shape"
-                )
-            else:
-                internal_in.opt = self.shape
-            internal_in.input_is_dynamic = False
-
-        if self.dtype != _enums.dtype.unknown:
-            self._explicit_set_dtype = True
-        else:
-            self._explicit_set_dtype = False
-
-        internal_in.dtype = Input._parse_dtype(self.dtype)
-        internal_in._explicit_set_dtype = self._explicit_set_dtype
-        internal_in.format = Input._parse_format(self.format)
-
-        internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain)
-        return internal_in
-
     @staticmethod
     def _supported_input_size_type(input_size: Any) -> bool:
         if isinstance(input_size, torch.Size):

Reviewer comment (on the removal of _to_internal): Why was this taken out?
@@ -304,6 +254,7 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
                 Input.low_tensor_domain_incl,
                 Input.high_tensor_domain_excl,
             )
+
         elif len(domain) == 2:
             domain_lo, domain_hi = domain
@@ -416,8 +367,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor:
             )

         if self.shape_mode == Input._ShapeMode.STATIC:
-            return torch.randn(self.shape).to(dtype=self.dtype)
+            return torch.randn(self.shape).to(
+                dtype=self.dtype if not self.torch_dtype else self.torch_dtype
+            )
         else:
             return torch.randn(self.shape[optimization_profile_field]).to(
-                dtype=self.dtype
+                dtype=self.dtype if not self.torch_dtype else self.torch_dtype
             )
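With these changes, an Input constructed with a torch.dtype keeps that dtype in torch_dtype, and example_tensor prefers it when materializing a random tensor. A minimal sketch of the intended behavior (the shape and dtype below are illustrative, not taken from the PR):

import torch
from torch_tensorrt import Input

# Static-shape Input; the torch.dtype passed here is recorded on torch_dtype.
inp = Input(shape=(1, 3, 224, 224), dtype=torch.half)

# example_tensor() now prefers torch_dtype, so the example tensor comes back as
# float16 without converting from the internal dtype enum.
t = inp.example_tensor()
assert t.dtype == torch.half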
@@ -4,6 +4,7 @@

 from .types import Shape, ShapeRange
 from .utils import get_dynamic_dims
+from torch_tensorrt._Input import Input


 def generate_input_specs(inputs, lower_setting, additional_inputs=None):

Reviewer comment (on generate_input_specs): Same here. These APIs are used in many internal products.
@@ -116,6 +117,43 @@ def from_tensors(cls, tensors: Sequence[torch.Tensor]) -> List["InputTensorSpec"]:
         assert isinstance(tensors, (list, tuple))
         return [cls.from_tensor(t) for t in tensors]

+    @classmethod
+    def from_input(cls, input_obj: Input) -> "InputTensorSpec":
+        """
+        Produce an InputTensorSpec named tuple from a torch_tensorrt.Input object.
+
+        Args:
+            input_obj (torch_tensorrt.Input): An Input object carrying shape and dtype information.
+
+        Returns:
+            An InputTensorSpec named tuple.
+        """
+        assert isinstance(input_obj, Input)
+        input_spec = None
+        if isinstance(input_obj.shape, dict):
+            min_shape = input_obj.shape["min_shape"]
+            opt_shape = input_obj.shape["opt_shape"]
+            max_shape = input_obj.shape["max_shape"]
+            dyn_shape = []
+            for min, opt, max in zip(min_shape, opt_shape, max_shape):
+                if min == opt == max:
+                    dyn_shape.append(min)
+                else:
+                    dyn_shape.append(-1)
+            dtype = input_obj.torch_dtype
+            input_spec = cls(
+                shape=dyn_shape,
+                dtype=dtype,
+                shape_ranges=[(min_shape, opt_shape, max_shape)],
+            )
+        else:
+            shape = input_obj.shape
+            dtype = input_obj.torch_dtype
+            input_spec = cls(shape=shape, dtype=dtype)
+
+        return input_spec
+
     @classmethod
     def from_tensors_with_dynamic_batch_size(
         cls,

Reviewer comment (on from_input): I assume we use input_obj as the general interface.
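As a rough illustration of what from_input produces, assuming InputTensorSpec stores shape, dtype, and shape_ranges as used above (the shapes are made up; only the batch dimension varies here, so it collapses to -1):

import torch
from torch_tensorrt import Input
from torch_tensorrt.fx import InputTensorSpec

dyn = Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(32, 3, 224, 224),
    dtype=torch.float32,
)
spec = InputTensorSpec.from_input(dyn)
# Expected: spec.shape == [-1, 3, 224, 224]
#           spec.shape_ranges == [((1, 3, 224, 224), (8, 3, 224, 224), (32, 3, 224, 224))]
#           spec.dtype == torch.float32

# A static Input maps straight through to a fixed-shape spec.
static_spec = InputTensorSpec.from_input(Input(shape=(8, 3, 224, 224), dtype=torch.float32))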
@@ -9,13 +9,14 @@
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
 from torch_tensorrt.fx.utils import LowerPrecision

-from ..input_tensor_spec import generate_input_specs
+from torch_tensorrt import _Input
+from ..input_tensor_spec import generate_input_specs, InputTensorSpec

 from ..lower_setting import LowerSetting
 from ..observer import Observer
 from ..passes.remove_duplicate_output_args import remove_duplicate_output_args
 from .graph_opts import common_subexpression_elimination
+from .pass_utils import extract_example_tensors_from_input

 from .lower_basic_pass import (  # noqa
     fix_clamp_numerical_limits_to_fp16,
@@ -165,6 +166,7 @@ def _split_pass(self) -> PassManager:
                     )
                 )
             )
+
         return PassManager.build_from_passlist(passes)

     def _trt_lower_pass(self) -> PassManager:
@@ -192,13 +194,17 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                 _LOGGER.info(f"Now lowering submodule {submod_name}")
                 lowering_start_time = datetime.datetime.now()

-                self.lower_setting.input_specs = generate_input_specs(
-                    submod_inputs,
-                    self.lower_setting,
-                    additional_submodule_inputs[submod_name]
-                    if additional_submodule_inputs
-                    else None,
-                )
+                if self._trt_input:
+                    self.lower_setting.input_specs = self._trt_input
+                else:
+                    self.lower_setting.input_specs = generate_input_specs(
+                        submod_inputs,
+                        self.lower_setting,
+                        additional_submodule_inputs[submod_name]
+                        if additional_submodule_inputs
+                        else None,
+                    )

                 lowered_module = self._lower_func(
                     submod, submod_inputs, self.lower_setting, submod_name
                 )
@@ -262,7 +268,13 @@ def _default_replace_mutable_op_pass(self) -> PassManager:
     def build_trt_lower_pipeline(
         self, input: Input, additional_input: Optional[Input] = None
     ) -> PassManager:
-        self._input = input
+        self._input = extract_example_tensors_from_input(input)
+        self._trt_input = []
+        for input_obj in input:
+            if isinstance(input_obj, _Input.Input):
+                self._trt_input.append(InputTensorSpec.from_input(input_obj))
+
         self._additional_input = additional_input
         passes = []

Reviewer comment (on the changes to build_trt_lower_pipeline): Can we start a new func instead of changing this one?
@@ -278,7 +290,13 @@ def build_trt_lower_pipeline(
     def build_aten2trt_lower_pipeline(
         self, input: Input, additional_input: Optional[Input] = None
     ) -> PassManager:
-        self._input = input
+        self._input = extract_example_tensors_from_input(input)
+        self._trt_input = []
+        for input_obj in input:
+            if isinstance(input_obj, _Input.Input):
+                self._trt_input.append(InputTensorSpec.from_input(input_obj))
+
         self._additional_input = additional_input
         passes = []
         passes.append(
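Net effect in the pass managers: when the caller hands in torch_tensorrt.Input objects, the pipeline keeps an example-tensor view in self._input and a spec view in self._trt_input, and the specs take precedence over generate_input_specs at lowering time. A hedged sketch of how this path might be exercised end to end, assuming the top-level torch_tensorrt.compile(..., ir="fx") entry point forwards inputs into build_trt_lower_pipeline (the module and shapes are placeholders, and the exact wiring is not shown in this diff):

import torch
import torch_tensorrt

class ToyModel(torch.nn.Module):  # placeholder module for illustration
    def forward(self, x):
        return torch.relu(x)

model = ToyModel().eval().cuda()

inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(32, 3, 224, 224),
        dtype=torch.float32,
    )
]

# With this PR, the FX lowering path can consume Input objects (including dynamic
# shape ranges) instead of pre-built example tensors.
trt_model = torch_tensorrt.compile(model, ir="fx", inputs=inputs)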
@@ -8,6 +8,7 @@
 import torch
 from torch import fx
 from torch.fx.passes.shape_prop import ShapeProp
+from torch_tensorrt import _Input

 # Create an alias for module input type to avoid littering pyre-ignore for Any
 # throughout the file.
@@ -21,6 +22,30 @@
 FINAL_CHECK_RTOL_MULTIPLIER: float = 10


+def extract_example_tensors_from_input(
+    inputs: Any, device: torch.device = torch.device("cuda")
+):
+    input_tensors = []
+    for input_obj in inputs:
+        if isinstance(input_obj, _Input.Input):
+            if isinstance(input_obj.shape, dict):
+                input_tensors.append(
+                    input_obj.example_tensor(optimization_profile_field="opt_shape").to(
+                        device
+                    )
+                )
+            else:
+                input_tensors.append(input_obj.example_tensor().to(device))
+        elif isinstance(input_obj, torch.Tensor):
+            input_tensors.append(input_obj)
+        else:
+            raise ValueError(
+                "Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor"
+            )
+
+    return input_tensors
+
+
 class RelaxAccuracyCheckMode:
     """
     Basically a context manager that controls a global variable that controls

Reviewer comment (on extract_example_tensors_from_input): Can we pass real tensors as input to the lowering workflow? Why do we need the …
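A small sketch of what the helper is expected to do with a mixed list (the import path follows the pass_utils location in this diff; the shapes and device are illustrative):

import torch
from torch_tensorrt import Input
from torch_tensorrt.fx.passes.pass_utils import extract_example_tensors_from_input

mixed = [
    torch.randn(2, 8),                          # real tensor: passed through as-is
    Input(shape=(2, 8), dtype=torch.float32),   # static Input: example_tensor() moved to cuda
    Input(min_shape=(1, 8), opt_shape=(4, 8), max_shape=(8, 8), dtype=torch.float32),
]

# The dynamic Input is materialized at its opt_shape, so the last tensor is (4, 8).
tensors = extract_example_tensors_from_input(mixed, device=torch.device("cuda"))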
@@ -114,10 +139,10 @@ def pass_with_validation(
             *args,
             **kwargs,
         ) -> fx.GraphModule:
-            res0 = module(*input)
+            input_tensors = extract_example_tensors_from_input(input)
+            res0 = module(*input_tensors)
             processed_module = pass_(module, input, *args, **kwargs)
-            res1 = processed_module(*input)
+            res1 = processed_module(*input_tensors)
             tensor_res_0 = _collect_tensors(res0)
             tensor_res_1 = _collect_tensors(res1)
             relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
Reviewer comment (on the new torch_dtype attribute): Should we derive torch_dtype from self.dtype?
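One possible answer to that question, sketched as a hypothetical helper rather than anything in this PR (the _to_torch_dtype name and the mapping are illustrative; only the enum members listed here are assumed to exist):

import torch
from torch_tensorrt import _enums

# Hypothetical: derive a torch.dtype from the internal dtype enum so torch_dtype
# would not need to be stored separately when the caller did not pass a torch.dtype.
_ENUM_TO_TORCH = {
    _enums.dtype.float: torch.float32,
    _enums.dtype.half: torch.float16,
    _enums.dtype.int8: torch.int8,
    _enums.dtype.int32: torch.int32,
}

def _to_torch_dtype(dtype):
    # Fall back to float32 when the dtype was never set explicitly.
    return _ENUM_TO_TORCH.get(dtype, torch.float32)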