feat: support for many padding dynamo converters #2482
Merged
Commits (11):
94c7cb4 feat: support constant padding dynamo converter (zewenli98)
c63cbaf feat: support reflection padding dynamo converters for 1D, 2D, and 3D (zewenli98)
215df67 feat: support replication padding dynamo converters for 1D, 2D, and 3D (zewenli98)
7e0e477 feat: support circular padding dynamo converters for 1D, 2D, and 3D (zewenli98)
78c2f43 feat: support pad dynamo converter (zewenli98)
ccc5d3c fix a concat bug (zewenli98)
d908601 update constant pad (zewenli98)
7534fa2 implement paddings via TRT ISliceLayer with different SliceMode (zewenli98)
1b21c9b fix import bug (zewenli98)
2e8c094 fix bugs (zewenli98)
a682814 add some small modifications (zewenli98)
The first changed file registers the new `pad` module in the converter implementation package's imports:

```diff
@@ -16,6 +16,7 @@
     linear,
     matmul,
     normalization,
+    pad,
     permutation,
     pool,
     reduce,
```
The second changed file is new (`@@ -0,0 +1,205 @@`) and contains the padding converter implementations:
```python
from typing import Optional, Sequence, Union

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    set_layer_name,
)
from torch_tensorrt.fx.types import TRTTensor

"""
Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
Use ISliceLayer to pad the tensor instead: it supports the new non-constant padding
modes (reflect, clamp, wrap) and supports padding outputs with dynamic shapes.
"""
```
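To make the shared slice-as-pad trick concrete, here is a minimal plain-Python sketch (illustrative, not part of the diff) of how each converter below derives the slice `start` and `shape`: a negative start pads before a dimension, and an enlarged shape pads after it.

```python
# pad follows the PyTorch convention, ordered from the last dimension backwards:
# (last_left, last_right, second_to_last_left, second_to_last_right, ...)
pad = [1, 2, 3, 0]
input_shape = [1, 3, 4, 5]

start = [0] * len(input_shape)
new_shape = list(input_shape)
for i in range(len(pad) // 2):
    start[-i - 1] = -pad[i * 2]                       # negative start -> pad before
    new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]  # larger shape  -> pad after

print(start)      # [0, 0, -3, -1]
print(new_shape)  # [1, 3, 7, 8]
```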
```python
def constant_padNd(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    pad: Sequence[int],
    value: Union[int, float] = 0,
) -> TRTTensor:
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

    rank = len(input.shape)

    if len(pad) // 2 > rank:
        raise RuntimeError(
            f"Trying to pad the last {len(pad) // 2} dimensions, but the input only has {rank} dimensions."
        )

    start_list = [0] * rank
    new_shape = list(input.shape)  # copy so the input's shape is not mutated

    # pad is ordered from the last dimension backwards:
    # (dim_-1_left, dim_-1_right, dim_-2_left, dim_-2_right, ...)
    for i in range(0, len(pad) // 2):
        start_list[-i - 1] = -pad[i * 2]
        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]

    stride_list = [1] * len(new_shape)
    layer = ctx.net.add_slice(
        input,
        start=tuple(start_list),
        shape=tuple(new_shape),
        stride=tuple(stride_list),
    )
    # Input index 4 of ISliceLayer supplies the fill value used by FILL mode.
    value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype)
    layer.set_input(4, value_const)
    layer.mode = trt.SliceMode.FILL

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)
```
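For reference, a small eager-mode PyTorch snippet showing the behavior `constant_padNd` mirrors (assuming standard `torch.nn.functional.pad` semantics; not part of the diff):

```python
import torch
import torch.nn.functional as F

x = torch.arange(6.0).reshape(2, 3)
# Pad the last dim: 1 element on the left, 2 on the right, filled with 0.5
y = F.pad(x, (1, 2), mode="constant", value=0.5)
print(y.shape)  # torch.Size([2, 6])
print(y[0])     # tensor([0.5000, 0.0000, 1.0000, 2.0000, 0.5000, 0.5000])
```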
```python
def reflection_padNd(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    padding: Sequence[int],
) -> TRTTensor:
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

    rank = len(input.shape)

    if len(padding) // 2 > rank:
        raise RuntimeError(
            f"Trying to pad the last {len(padding) // 2} dimensions, but the input only has {rank} dimensions."
        )

    start_list = [0] * rank
    new_shape = list(input.shape)  # copy so the input's shape is not mutated

    for i in range(0, len(padding) // 2):
        start_list[-i - 1] = -padding[i * 2]
        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]

    stride_list = [1] * len(new_shape)
    layer = ctx.net.add_slice(
        input,
        start=tuple(start_list),
        shape=tuple(new_shape),
        stride=tuple(stride_list),
    )
    # REFLECT mirrors interior elements across the edge (edge itself not repeated).
    layer.mode = trt.SliceMode.REFLECT

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def replication_padNd(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    padding: Sequence[int],
) -> TRTTensor:
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

    rank = len(input.shape)

    if len(padding) // 2 > rank:
        raise RuntimeError(
            f"Trying to pad the last {len(padding) // 2} dimensions, but the input only has {rank} dimensions."
        )

    start_list = [0] * rank
    new_shape = list(input.shape)  # copy so the input's shape is not mutated

    for i in range(0, len(padding) // 2):
        start_list[-i - 1] = -padding[i * 2]
        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]

    stride_list = [1] * len(new_shape)
    layer = ctx.net.add_slice(
        input,
        start=tuple(start_list),
        shape=tuple(new_shape),
        stride=tuple(stride_list),
    )
    # CLAMP repeats the edge element, matching PyTorch's "replicate" mode.
    layer.mode = trt.SliceMode.CLAMP

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)
```
```python
def circular_padNd(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    pad: Sequence[int],
) -> TRTTensor:
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

    rank = len(input.shape)

    if len(pad) // 2 > rank:
        raise RuntimeError(
            f"Trying to pad the last {len(pad) // 2} dimensions, but the input only has {rank} dimensions."
        )

    start_list = [0] * rank
    new_shape = list(input.shape)  # copy so the input's shape is not mutated

    for i in range(0, len(pad) // 2):
        start_list[-i - 1] = -pad[i * 2]
        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]

    stride_list = [1] * len(new_shape)
    layer = ctx.net.add_slice(
        input,
        start=tuple(start_list),
        shape=tuple(new_shape),
        stride=tuple(stride_list),
    )
    # WRAP wraps around to the opposite edge, matching PyTorch's "circular" mode.
    layer.mode = trt.SliceMode.WRAP

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)
```
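The four converters above differ only in the `trt.SliceMode` they select. A summary sketch of the mapping (the dict below is illustrative, not part of the diff):

```python
# PyTorch pad mode -> TensorRT slice mode used by the converters above
PAD_MODE_TO_SLICE_MODE = {
    "constant": trt.SliceMode.FILL,    # fill new elements with a constant value
    "reflect": trt.SliceMode.REFLECT,  # mirror interior elements across the edge
    "replicate": trt.SliceMode.CLAMP,  # repeat (clamp to) the edge element
    "circular": trt.SliceMode.WRAP,    # wrap around to the opposite edge
}
```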
```python
def pad(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    pad: Sequence[int],
    mode: str = "constant",
    value: Optional[float] = None,
) -> TRTTensor:
    if mode == "constant":
        return constant_padNd(
            ctx,
            target,
            source_ir,
            f"{name}_{mode}",
            input,
            pad,
            value if value is not None else 0,
        )
    elif mode == "reflect":
        return reflection_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
    elif mode == "replicate":
        return replication_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
    elif mode == "circular":
        return circular_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
    else:
        raise RuntimeError(
            f'We currently only support `mode` in ["constant", "reflect", "replicate", "circular"], but got {mode}'
        )
```
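In eager terms, a hedged sketch of which converter each `torch.nn.functional.pad` call would dispatch to once this dispatcher handles the graph (illustrative, not part of the diff):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)

F.pad(x, (1, 1, 2, 2), mode="constant", value=0.5)  # -> constant_padNd    (FILL)
F.pad(x, (1, 1, 2, 2), mode="reflect")              # -> reflection_padNd  (REFLECT)
F.pad(x, (1, 1, 2, 2), mode="replicate")            # -> replication_padNd (CLAMP)
F.pad(x, (1, 1, 2, 2), mode="circular")             # -> circular_padNd    (WRAP)
```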