Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ __pycache__/
*.ipynb_checkpoints
*.pth
docs/converters.md
site
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

## [0.2.0] - 03/02/2021

- Added converter for ``torch.Tensor.flatten``
- Added converter for ``torch.nn.functional.conv2d`` and ``torch.nn.functional.conv3d``

### Added

- Added converter for ``torch.Tensor.expand``
Expand Down
1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .AdaptiveAvgPool2d import *
from .BatchNorm1d import *
from .BatchNorm2d import *
from .conv_functional import *
from .Conv import *
from .Conv1d import *
from .Conv2d import *
Expand Down
127 changes: 127 additions & 0 deletions torch2trt/converters/conv_functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.nn.functional.conv2d', enabled=trt_version() >= '7.0')
@tensorrt_converter('torch.nn.functional.conv3d', enabled=trt_version() >= '7.0')
def convert_Conv_trt7_functional(ctx):
input = get_arg(ctx, 'input', pos=0, default=None)
weight = get_arg(ctx, 'weight', pos=1, default=None)
bias = get_arg(ctx, 'bias', pos=2, default=None)
stride = get_arg(ctx, 'stride', pos=3, default=1)
padding = get_arg(ctx, 'padding', pos=4, default=0)
dilation = get_arg(ctx, 'dilation', pos=5, default=1)
groups = get_arg(ctx, 'groups', pos=6, default=1)

input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
output = ctx.method_return

input_dim = input.dim() - 2

out_channels = int(weight.shape[0])
kernel_size = tuple(weight.shape[2:])
if not isinstance(kernel_size, tuple):
kernel_size = (kernel_size, ) * input_dim

if not isinstance(stride, tuple):
stride = (stride, ) * input_dim

if not isinstance(padding, tuple):
padding = (padding, ) * input_dim

if not isinstance(dilation, tuple):
dilation = (dilation, ) * input_dim

kernel = weight.detach().cpu().numpy()

if bias is not None:
bias = bias.detach().cpu().numpy()

layer = ctx.network.add_convolution_nd(
input=input_trt,
num_output_maps=out_channels,
kernel_shape=kernel_size,
kernel=kernel,
bias=bias)
layer.stride_nd = stride
layer.padding_nd = padding
layer.dilation_nd = dilation

if groups is not None:
layer.num_groups = groups

output._trt = layer.get_output(0)


class FunctionalConv2d(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.conv = torch.nn.Conv2d(*args, **kwargs)

def forward(self, x):
x = torch.nn.functional.conv2d(
x,
self.conv.weight,
self.conv.bias,
self.conv.stride,
self.conv.padding,
self.conv.dilation,
self.conv.groups
)
return x

class FunctionalConv3d(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.conv = torch.nn.Conv3d(*args, **kwargs)

def forward(self, x):
x = torch.nn.functional.conv3d(
x,
self.conv.weight,
self.conv.bias,
self.conv.stride,
self.conv.padding,
self.conv.dilation,
self.conv.groups
)
return x

@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0')
def test_Conv2d_basic_trt7_functional():
return FunctionalConv2d(10, 5, kernel_size=1, stride=1, padding=0)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0')
def test_Conv2d_stride2_trt7_functional():
return FunctionalConv2d(10, 5, kernel_size=1, stride=2, padding=0)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0')
def test_Conv2d_kernel3_trt7_functional():
return FunctionalConv2d(10, 5, kernel_size=3, stride=2, padding=1)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0')
def test_Conv2d_dilation2_trt7_functional():
return FunctionalConv2d(10, 5, kernel_size=3, stride=1, padding=1, dilation=2)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0')
def test_Conv3d_basic_trt7_functional():
return FunctionalConv3d(10, 5, kernel_size=1, stride=1, padding=0)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0')
def test_Conv3d_stride2_trt7_functional():
return FunctionalConv3d(10, 5, kernel_size=1, stride=2, padding=0)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0')
def test_Conv3d_kernel3_trt7_functional():
return FunctionalConv3d(10, 5, kernel_size=3, stride=2, padding=1)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0')
def test_Conv3d_dilation2_trt7_functional():
return FunctionalConv3d(10, 5, kernel_size=3, stride=1, padding=1, dilation=2)
1 change: 1 addition & 0 deletions torch2trt/converters/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@tensorrt_converter('torch.Tensor.view')
@tensorrt_converter('torch.Tensor.squeeze')
@tensorrt_converter('torch.Tensor.unsqueeze')
@tensorrt_converter('torch.Tensor.flatten')
@tensorrt_converter('torch.squeeze')
@tensorrt_converter('torch.unsqueeze')
def convert_view(ctx):
Expand Down