feat: Implement Dynamic shapes + fallback support for export path #2271
Merged
Commits (87)
f1f202e
feat: Move tracing to use aot export apis
peri044 abaf047
chore: minor changes
peri044 bb1f3cf
chore: minor changes
peri044 3d05b4d
chore: Rebase with main
peri044 8d99be5
chore: rebase
peri044 0aad214
chore: minor logging updates
peri044 8899735
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive 8af2627
fix: Refactor tensor freezing in Dynamo
gs-olive f6969be
Key op fixes for failing tests
gs-olive bad1594
fix: Add constant folding utility to freezing
gs-olive db56dd6
chore: Move to new export APIs
peri044 bf961f5
chore: rebase with dynamo_tensor_freeze branch
peri044 b13aa82
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive dd95620
fix: Refactor tensor freezing in Dynamo
gs-olive 6bd3c64
Key op fixes for failing tests
gs-olive 248073f
fix: Add constant folding utility to freezing
gs-olive 3e5f434
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive 6bf6945
fix: Refactor tensor freezing in Dynamo
gs-olive 3b6e1e7
Key op fixes for failing tests
gs-olive 2107d8e
fix: Add constant folding utility to freezing
gs-olive fd5a41e
chore: add BERT test case
peri044 f047651
chore: remove pdb
peri044 de9795d
feat: Implement dynamic shapes feature
peri044 ab76c0d
chore: rebase
peri044 5f2a4f3
chore: minor update
peri044 566fbb0
Merge branch 'export_prototype' into dyn_export
peri044 4949549
chore: refactor
peri044 e4df382
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive d022f4a
fix: Refactor tensor freezing in Dynamo
gs-olive 9610ba7
Key op fixes for failing tests
gs-olive e19aae7
fix: Add constant folding utility to freezing
gs-olive d95c360
chore: Add constraints for dynamic inputs during export
peri044 2860be6
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 d7f2477
chore: rebase with export_prototype
peri044 b50d362
chore: enable truncate long and double inputs
peri044 91b47fb
chore: updates
peri044 51266db
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive 2005db7
fix: Add constant folding utility to freezing
gs-olive a8cb1fe
fix: Move tracer code into try/except
gs-olive 7ff9309
Custom implementation of AOT for compile
gs-olive 692921e
Move fixes into Dynamo directory
gs-olive e926724
chore: rebase
peri044 0cfd23b
Merge branch 'export_prototype' into dyn_export
peri044 2c85bc7
chore: minor changes
peri044 1de79b3
chore: add device updates
peri044 33ddf46
chore: minor updates
peri044 39e7d98
chore: refactor prepare_inputs
peri044 760eda6
chore: minor updates
peri044 b3c9666
chore: updates
peri044 fbfb8ef
chore: updates
peri044 b7056a1
chore: add tests and update GHA
peri044 27681c2
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive 056cbf3
fix: Add constant folding utility to freezing
gs-olive ece276c
fix: Move tracer code into try/except
gs-olive 73a0bce
Custom implementation of AOT for compile
gs-olive 890ba72
Move fixes into Dynamo directory
gs-olive 980dc1c
chore: rebase
peri044 dfc4899
Move fixes into Dynamo directory
gs-olive 09b099a
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 157bb2d
chore: updates
peri044 0005a31
Move fixes into Dynamo directory
gs-olive 5526bca
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 3420fb0
chore: updates
peri044 4a0afd3
Merge branch 'export_prototype' into dyn_export
peri044 399f929
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive 4b44ff2
fix: Add constant folding utility to freezing
gs-olive a94a075
fix: Move tracer code into try/except
gs-olive 4e308f1
Custom implementation of AOT for compile
gs-olive 95d3f98
Move fixes into Dynamo directory
gs-olive 529262a
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 16d19e8
Merge branch 'export_prototype' into dyn_export
peri044 aee529b
chore: rebase
peri044 89acc8e
Merge branch 'export_prototype' into dyn_export
peri044 6858c0a
chore: rebase
peri044 695dc9b
chore: address review comments
peri044 a5cdd24
chore: updates
peri044 24922fc
chore: updates
peri044 645816e
chore: updates
peri044 708ac64
chore: rebase with main
peri044 9bcaf49
chore: update docs
peri044 6704cb7
chore: update docs
peri044 560c779
chore: update docs
peri044 03f5f2d
chore: rebase
peri044 0349810
chore: fix tests
peri044 912fcab
chore: updates
peri044 31c09b2
chore: revert harness tracer changes
peri044 9f0a589
chore: address review comments
peri044
@@ -60,6 +60,8 @@ class DataType {
   enum Value : int8_t {
     /// INT64
     kLong,
+    /// FP64
+    kDouble,
     /// FP32
     kFloat,
     /// FP16
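This hunk adds an FP64 (``kDouble``) value to the C++ ``DataType`` enum, which lines up with the PR commit "chore: enable truncate long and double inputs". As a hedged illustration (not code from this PR), a user might combine double-precision inputs with the existing ``truncate_long_and_double`` compile flag; ``MyModel`` is a placeholder:

```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # placeholder model that accepts an fp64 input
inputs = [torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float64)]
# truncate_long_and_double downcasts int64/fp64 inputs to int32/fp32 at the TensorRT boundary
trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, truncate_long_and_double=True)
```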
@@ -0,0 +1,218 @@

.. _runtime:

Dynamic shapes with Torch-TensorRT
====================================

By default, you can run a PyTorch model with varied input shapes, and the output shapes are determined eagerly.
However, Torch-TensorRT is an AOT compiler which requires prior information about the input shapes to compile and optimize the model.
In the case of dynamic input shapes, we must provide the ``(min_shape, opt_shape, max_shape)`` arguments so that the model can be optimized for
this range of input shapes. An example usage of static and dynamic shapes is as follows.

NOTE: The following code uses the Dynamo IR. For the TorchScript IR, swap ``ir=dynamo`` for ``ir=ts``; the behavior is exactly the same.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    # Compile with static shapes
    inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
    # or compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                                  opt_shape=[4, 3, 224, 224],
                                  max_shape=[8, 3, 224, 224],
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

Under the hood
--------------

There are two phases of compilation when we use the ``torch_tensorrt.compile`` API with ``ir=dynamo`` (default).

- aten_tracer.trace (which uses torch.export to trace the graph with the given inputs)

  In the tracing phase, we use torch.export along with the constraints. In the case of
  dynamic shaped inputs, the range can be provided to the tracing via constraints. Please
  refer to this `docstring <https://github.com/pytorch/pytorch/blob/5dcee01c2b89f6bedeef9dd043fd8d6728286582/torch/export/__init__.py#L372-L434>`_
  for detailed information on how to set constraints. In short, we create new inputs for
  torch.export tracing and provide constraints on the min and max values (provided by the user) that a particular dimension can take.
  Please take a look at the ``aten_tracer.py`` file to understand how this works under the hood.

- dynamo.compile (which compiles a torch.fx.GraphModule object using TensorRT)

  In the conversion to TensorRT, we use the user-provided dynamic shape inputs.
  We perform shape analysis using dummy inputs (across the min, opt and max shapes) and store the
  intermediate output shapes, which can be used when the graph has a mix of PyTorch
  and TensorRT submodules. A simplified sketch of this shape-analysis idea is shown after this list.

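The sketch below uses standard ``torch.fx`` utilities to illustrate the idea of recording per-node
output shapes for dummy inputs at each of the min/opt/max shapes. It is illustrative only and not the
actual Torch-TensorRT implementation; the ``record_shapes`` helper and the single-input assumption are ours.

.. code-block:: python

    import torch
    from torch.fx.passes.shape_prop import ShapeProp

    def record_shapes(gm: torch.fx.GraphModule, shapes: dict):
        """Run dummy inputs at each named shape and collect per-node output shapes."""
        results = {}
        for name, shape in shapes.items():  # e.g. {"min": (1, 3, 224, 224), "opt": ..., "max": ...}
            ShapeProp(gm).propagate(torch.randn(shape))  # annotates node.meta["tensor_meta"]
            results[name] = {
                node.name: tuple(node.meta["tensor_meta"].shape)
                for node in gm.graph.nodes
                if "tensor_meta" in node.meta
            }
        return results
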
Custom Constraints
------------------

Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
Torch-TensorRT automatically sets the constraints during ``torch.export`` tracing as follows:

.. code-block:: python

    for dim in constraint_dims:
        if min_shape[dim] > 1:
            constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
        if max_shape[dim] > 1:
            constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])

Sometimes we might need to set additional constraints, and TorchDynamo errors out if we don't specify them.
For example, in the case of BERT model compilation there are two inputs, and a constraint has to be set involving the sequence length of these two inputs.

.. code-block:: python

    constraints.append(dynamic_dim(trace_inputs[0], 0) == dynamic_dim(trace_inputs[1], 0))

If you have to provide any custom constraints to your model, the overall workflow for model compilation using ``ir=dynamo`` would involve a few steps.

.. code-block:: python

    import unittest.mock

    import torch
    import torch_tensorrt
    # On the PyTorch version targeted here these live in torch._export;
    # newer versions expose them via torch.export.
    from torch._export import dynamic_dim, export
    from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions

    # Assume the model has two inputs
    model = MyModel()
    torch_input_1 = torch.randint(0, 2, (1, 14), dtype=torch.int32).cuda()
    torch_input_2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).cuda()

    dynamic_inputs = [torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32),
                      torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32)]

    # Export the model with additional constraints
    constraints = []
    # The following constraints are automatically added by Torch-TensorRT in the
    # general case when you call torch_tensorrt.compile directly on MyModel()
    constraints.append(dynamic_dim(torch_input_1, 0) < 8)
    constraints.append(dynamic_dim(torch_input_2, 0) < 8)
    # This is an additional constraint as instructed by Torchdynamo
    constraints.append(dynamic_dim(torch_input_1, 0) == dynamic_dim(torch_input_2, 0))
    # experimental_decompositions and compile_spec are user-provided settings
    with unittest.mock.patch(
        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
    ):
        graph_module = export(
            model, (torch_input_1, torch_input_2), constraints=constraints
        ).module()

    # Use the dynamo.compile API
    trt_mod = torch_tensorrt.dynamo.compile(graph_module, inputs=dynamic_inputs, **compile_spec)

Limitations
-----------

If there are operations in the graph that use the dynamic dimension of the input, PyTorch
introduces ``torch.ops.aten.sym_size.int`` ops in the graph. Currently, we cannot handle these operators and
the compilation results in undefined behavior. We plan to add support for these operators and implement
robust support for shape tensors in the next release. Here is an example of the limitation described above:

.. code-block:: python

    import torch
    import torch_tensorrt

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.avgpool(x)
            out = torch.flatten(x, 1)
            return out

    model = MyModule().eval().cuda()
    # Compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=(1, 512, 1, 1),
                                  opt_shape=(4, 512, 1, 1),
                                  max_shape=(8, 512, 1, 1),
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

The traced graph of ``MyModule()`` looks as follows:

.. code-block:: python

    Post export graph: graph():
        %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%arg0_1, [-1, -2], True), kwargs = {})
        %sym_size : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
        %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mean, [%sym_size, 512]), kwargs = {})
        return (view,)

Here the ``%sym_size`` node captures the dynamic batch size and uses it in the ``aten.view`` layer. This requires shape tensor support,
which will be part of our next release.

Workaround (BERT static compilation example)
--------------------------------------------

If you encounter the issues mentioned in the **Limitations** section above,
you can compile the model in static mode with the maximum input size that can be provided. For smaller inputs,
you can pad them accordingly. This is only a workaround until we address the limitations.

.. code-block:: python

    import torch
    import torch_tensorrt
    from transformers import BertModel
    from transformers.utils.fx import symbolic_trace as transformers_trace

    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()

    # Input sequence length is 20.
    input1 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")

    # compile_spec holds any additional compilation settings
    model = transformers_trace(model, input_names=["input_ids", "attention_mask"]).eval().cuda()
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = trt_mod(input1, input2)

    # If you have a sequence of length 14, pad 6 zero tokens and run inference
    # or recompile for a sequence length of 14.
    input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = trt_mod(input1, input2)

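For example, a minimal sketch of padding a length-14 sequence up to the compiled length of 20 might look as follows. It is illustrative only, assumes a pad token id of 0, and reuses ``trt_mod`` from the block above.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    seq = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")      # length-14 input_ids
    pad_len = 20 - seq.shape[1]                                           # tokens needed to reach length 20
    padded_ids = F.pad(seq, (0, pad_len), value=0)                        # assumed pad token id of 0
    attention_mask = F.pad(torch.ones_like(seq), (0, pad_len), value=0)   # mask out the padded positions
    model_outputs = trt_mod(padded_ids, attention_mask)
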
Dynamic shapes with ir=torch_compile
------------------------------------

``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
configured to TensorRT. In the case of ``ir=torch_compile``, users have to recompile for different input shapes.
In the future, we plan to explore the option of compiling with dynamic shapes in the first execution of the model.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs)
    # Compilation happens when you call the model
    trt_gm(inputs)

    # Recompilation happens with a modified batch size
    inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs_bs2)
    trt_gm(inputs_bs2)
What if `torch_tensor` is provided, but the shape is also dynamic - would it be an issue if this tensor had the shape of `min` or `max` instead of `opt`? This could be refactored to override or validate the specified tensor if the shape is dynamic.
So the use case you are describing is when the user provides

`x = torch_tensorrt.Input(min_shape=<>, opt_shape=<>, max_shape=<>, torch_tensor=<random_tensor>)`

This wouldn't be an issue, as we ignore `torch_tensor` in this case of dynamic shape compilation.

TensorRT/py/torch_tensorrt/dynamo/partitioning/common.py (Line 48 in 695dc9b)

We could maybe pass a warning to users that `torch_tensor` is not being used in such cases.
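A hypothetical sketch of such a warning (the helper, its placement, and the attribute access are assumptions, not the actual `common.py` code):

```python
import logging

logger = logging.getLogger(__name__)

def warn_if_torch_tensor_ignored(input_spec):
    # Hypothetical helper: if a dynamic-shape Input also carries a concrete tensor,
    # tell the user that the tensor is ignored and dummy tensors are generated
    # from the (min, opt, max) shapes instead.
    is_dynamic = getattr(input_spec, "min_shape", None) is not None
    if is_dynamic and getattr(input_spec, "torch_tensor", None) is not None:
        logger.warning(
            "torch_tensor was provided for a dynamic-shape Input and will be ignored; "
            "dummy tensors are generated from (min_shape, opt_shape, max_shape)."
        )
```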
I see - I believe the feature ask from #2323 is to allow providing tensors for each of `min`, `opt`, and `max`. Could `torch_tensor` instead be either a list/tuple of 3 tensors (`min`, `opt`, `max`) or a single tensor?
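Purely as an illustration of the two options being discussed (not an agreed-upon API):

```python
import torch
import torch_tensorrt

# Option A: a single example tensor, as today
x = torch_tensorrt.Input(min_shape=[1, 14], opt_shape=[4, 14], max_shape=[8, 14],
                         dtype=torch.int32,
                         torch_tensor=torch.randint(0, 2, (4, 14), dtype=torch.int32))

# Option B: one tensor per profile shape (min, opt, max)
x = torch_tensorrt.Input(min_shape=[1, 14], opt_shape=[4, 14], max_shape=[8, 14],
                         dtype=torch.int32,
                         torch_tensor=(torch.randint(0, 2, (1, 14), dtype=torch.int32),
                                       torch.randint(0, 2, (4, 14), dtype=torch.int32),
                                       torch.randint(0, 2, (8, 14), dtype=torch.int32)))
```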