Skip to content

feat: support deconv (1d, 2d, and Nd) dynamo converter #2337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,14 +357,14 @@ def aten_ops_softmax(

@dynamo_tensorrt_converter(
torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
)
) # type: ignore[misc]
@dynamo_tensorrt_converter(
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
)
) # type: ignore[misc]
@dynamo_tensorrt_converter(
torch.ops.aten.split_with_sizes.default,
capability_validator=dynamic_unsupported_with_args([1]),
)
) # type: ignore[misc]
def aten_ops_split(
network: TRTNetwork,
target: Target,
Expand Down Expand Up @@ -1331,7 +1331,7 @@ def aten_ops_less(


def conv_param_validator(conv_node: Node) -> bool:
return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))
return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])


@dynamo_tensorrt_converter(
Expand All @@ -1344,20 +1344,37 @@ def aten_ops_convolution(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.conv.convNd(
network,
target,
source_ir=SourceIR.ATEN,
name=name,
is_conv1d=len(args[3]) == 1,
input=args[0],
weight=args[1],
bias=args[2],
stride=args[3],
padding=args[4],
dilation=args[5],
groups=args[8],
)
is_transposed = args[6]
if not is_transposed:
return impl.conv.convNd(
network,
target,
source_ir=SourceIR.ATEN,
name=name,
is_conv1d=len(args[3]) == 1,
input=args[0],
weight=args[1],
bias=args[2],
stride=args[3],
padding=args[4],
dilation=args[5],
groups=args[8],
)
else:
return impl.deconv.deconvNd(
network,
target,
source_ir=SourceIR.ATEN,
name=name,
is_deconv1d=len(args[3]) == 1,
input=args[0],
weight=args[1],
bias=args[2],
stride=args[3],
padding=args[4],
dilation=args[5],
groups=args[8],
)


@dynamo_tensorrt_converter(torch.ops.aten.linear.default) # type: ignore[misc]
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
cast,
condition,
conv,
deconv,
elementwise,
embedding,
linear,
Expand Down
140 changes: 140 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/deconv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import Optional, Sequence, Union

import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion.converter_utils import (
extend_attr_to_tuple,
get_trt_tensor,
)
from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
get_dyn_range,
has_dynamic_shape,
mark_as_int8_layer,
set_layer_name,
to_numpy,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def deconvNd(
network: TRTNetwork,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
is_deconv1d: bool,
input: TRTTensor,
weight: Union[TRTTensor, torch.Tensor, np.ndarray],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
stride: Optional[Union[int, Sequence[int]]],
padding: Optional[Union[int, Sequence[int]]],
groups: Optional[int],
dilation: Optional[Union[int, Sequence[int]]],
scale: Optional[Union[torch.Tensor, float]] = None,
zero_point: Optional[Union[torch.Tensor, float]] = None,
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for deconvolution."

if is_deconv1d:
# Apply an unsqueeze operation to transform the deconv1d problem into deconv2d
input = impl.unsqueeze.unsqueeze(
network, target, source_ir, name + "_unsqueeze_deconv1d", input, -1
)

# Process bias terms
if isinstance(bias, (torch.Tensor, np.ndarray)):
# Transform the bias constant into a Numpy array
bias = to_numpy(bias)

elif isinstance(bias, TRTTensor):
bias = get_trt_tensor(network, bias, f"{name}_bias")

elif bias is not None:
raise RuntimeError(
f"Deconvolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
)

# Process weight terms
if network.has_explicit_precision or isinstance(weight, TRTTensor):
weight = get_trt_tensor(network, weight, f"{name}_weight")
# Append new dimension (unsqueeze) if the deconvolution is 1d
if is_deconv1d:
input = impl.unsqueeze.unsqueeze(
network, target, source_ir, name + "_unsqueeze_weight", weight, -1
)

elif isinstance(weight, (torch.Tensor, np.ndarray)):
# Transform the weight constant into a Numpy array
weight = to_numpy(weight)

# Append new dimension (unsqueeze) if the deconvolution is 1d
if is_deconv1d:
weight = np.expand_dims(weight, axis=-1)

else:
raise RuntimeError(
f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
)

# add deconv layer
deconv_layer = network.add_deconvolution_nd(
input=input,
num_output_maps=weight.shape[0],
kernel_shape=weight.shape[2:],
kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
)

# If the weight is a TRTTensor, set it as an input of the layer
if isinstance(weight, TRTTensor):
deconv_layer.set_input(1, weight)

# If the bias is a TRTTensor, set it as an input of the layer
if isinstance(bias, TRTTensor):
deconv_layer.set_input(2, bias)

# Cast certain fields to tuples, in accordance with TRT requirements
padding = (padding,) if isinstance(padding, int) else padding
stride = (stride,) if isinstance(stride, int) else stride
dilation = (dilation,) if isinstance(dilation, int) else dilation

# Expand parameters manually for Conv1D computations
if is_deconv1d:
padding = (tuple(padding) + (0,)) if padding is not None else padding
stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
dilation = (
extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
)

set_layer_name(deconv_layer, target, name, source_ir)

# Set relevant attributes of deconvolution layer
if padding is not None:
deconv_layer.padding_nd = padding
if stride is not None:
deconv_layer.stride_nd = stride
if dilation is not None:
deconv_layer.dilation_nd = dilation
if groups is not None:
deconv_layer.num_groups = groups

# Handle quantization cases
if scale is not None and zero_point is not None:
# Assume the dtype of activation is torch.quint8
mark_as_int8_layer(deconv_layer, get_dyn_range(scale, zero_point, torch.quint8))

result = deconv_layer.get_output(0)

if is_deconv1d:
# Apply a squeeze operation to transform the deconv2d problem back into deconv1d
result = impl.squeeze.squeeze(
network, target, source_ir, name + "_squeeze_deconv1d", result, -1
)

return result
Loading