Merged
29 commits
2d00347
Update dummy_converters.py
JonathanSKent Mar 20, 2021
febda4d
add equivalent 3d functions for most 2d functions
dmenig Jul 8, 2021
8ad6b55
Merge branch '3d' of https://github.com/hyperfraise/torch2trt into hy…
Jul 15, 2021
e62afd6
fixed bug in max_pool3d stride->stride_nd
Jul 15, 2021
9a11257
updated changelog to include new 3d operations
Jul 15, 2021
003af69
Merge branch 'hyperfraise-3d'
Jul 15, 2021
f6ea401
increment to version 0.3.0
Jul 15, 2021
c50039a
Merge pull request #594 from NVIDIA-AI-IOT/v0.3.0
jaybdub Jul 15, 2021
f1f96cc
Merge branch 'patch-1' of https://github.com/JonathanSKent/torch2trt …
Jul 19, 2021
eb07e6d
Merge pull request #595 from NVIDIA-AI-IOT/JonathanSKent-patch-1
jaybdub Jul 19, 2021
8f74290
switch to builder config
jaybdub Jul 27, 2021
4f42f91
handle trt8 resize align corners api
jaybdub Jul 27, 2021
72a8518
added psnr for unit tests
Aug 1, 2021
b86649b
fixed qat doc and made the output of test.py cleaner
Aug 1, 2021
1b1293c
fixed display precision
Aug 2, 2021
4393e4f
added color for psnr failure in the unit tests
Aug 2, 2021
311f328
Merge pull request #602 from SrivastavaKshitij/fix_test_and_doc
jaybdub Aug 9, 2021
6a418f9
Add torch.clone to converters.
chaoz-dev Sep 20, 2021
faf38b2
Add torch.nn.functional.max_pool1d to converters.
chaoz-dev Sep 21, 2021
82795eb
Set layer precision.
chaoz-dev Sep 26, 2021
572d422
Set layer precision.
chaoz-dev Sep 26, 2021
1107330
update changelog
jaybdub Sep 29, 2021
dd9950e
Merge pull request #599 from jaybdub/trt-8
jaybdub Sep 29, 2021
689cec7
added docker files for trt7/8 testing
jaybdub Sep 29, 2021
15b7f89
Merge pull request #641 from jaybdub/docker_files
jaybdub Sep 29, 2021
d2ebdaf
Merge pull request #633 from chaoz-dev/chaoz-dev/converters-clone
jaybdub Sep 29, 2021
0400b38
Merge pull request #634 from chaoz-dev/chaoz-dev/converters-maxpool1d
jaybdub Sep 29, 2021
5522e23
added matmul
orilador Oct 18, 2021
5839e9a
Merge https://github.com/NVIDIA-AI-IOT/torch2trt into support_trt_8_m…
orilador Oct 18, 2021
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,13 @@

## [Master]

- Added support for TensorRT 8

## [0.3.0] - 07/15/2021

- Added converter for ``torch.nn.functional.adaptive_avg_pool3d``
- Added converter for ``torch.nn.functional.adaptive_max_pool3d``
- Added converter for ``torch.max_pool3d`` and ``torch.nn.functional.max_pool3d``
- Added Quantization Aware Training (QAT) workflow to contrib
- Added converter for ``torch.roll``
- Added converter for ``torch.nn.functional.layer_norm``
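
Note: a minimal usage sketch of one of the new 3D converters, assuming a CUDA device and a working TensorRT installation (shapes are illustrative):

```python
# Minimal sketch: convert a module using the new AdaptiveAvgPool3d converter.
import torch
from torch2trt import torch2trt

model = torch.nn.AdaptiveAvgPool3d((1, 1, 1)).cuda().eval()
x = torch.randn(1, 3, 16, 112, 112).cuda()

model_trt = torch2trt(model, [x])  # builds a TensorRT engine

# Sanity check: the TensorRT output should closely match PyTorch's.
print(torch.max(torch.abs(model(x) - model_trt(x))))
```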
4 changes: 4 additions & 0 deletions docker/21-06/Dockerfile
@@ -0,0 +1,4 @@
FROM nvcr.io/nvidia/pytorch:21.06-py3


RUN pip3 install termcolor
3 changes: 3 additions & 0 deletions docker/21-06/build.sh
@@ -0,0 +1,3 @@
#!/bin/bash

docker build -t torch2trt:21-06 -f $(pwd)/docker/21-06/Dockerfile .
4 changes: 4 additions & 0 deletions docker/21-06/run.sh
@@ -0,0 +1,4 @@
#!/bin/bash


docker run --gpus all -it --rm -v $(pwd):/torch2trt torch2trt:21-06
4 changes: 4 additions & 0 deletions docker/21-09/Dockerfile
@@ -0,0 +1,4 @@
FROM nvcr.io/nvidia/pytorch:21.09-py3


RUN pip3 install termcolor
3 changes: 3 additions & 0 deletions docker/21-09/build.sh
@@ -0,0 +1,3 @@
#!/bin/bash

docker build -t torch2trt:21-09 -f $(pwd)/docker/21-09/Dockerfile .
4 changes: 4 additions & 0 deletions docker/21-09/run.sh
@@ -0,0 +1,4 @@
#!/bin/bash


docker run --gpus all -it --rm -v $(pwd):/torch2trt torch2trt:21-09
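
Note: per the commit history above ("added docker files for trt7/8 testing"), the 21.06 image targets a TensorRT 7 environment and the 21.09 image a TensorRT 8 environment. Both `build.sh` and `run.sh` reference `$(pwd)`, so they are meant to be run from the repository root, which is also the directory mounted into the container at `/torch2trt`.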
24 changes: 3 additions & 21 deletions examples/contrib/quantization_aware_training/README.md
@@ -25,27 +25,9 @@ RUN add-apt-repository ppa:git-core/ppa && \

RUN pip install termcolor graphviz

## If you have followed instructions on main README.md file to install torch2trt using scripts/build_contrib.sh
## You dont require rest of the steps

RUN git clone https://github.com/NVIDIA/TensorRT.git /sw/TensorRT/

##Make sure that patch file is under the same folder where dockerfile is being called

ADD pytorch_nvidia_quantization.patch /sw/TensorRT

RUN cd /sw/TensorRT/ && \
git sparse-checkout init --cone && \
git sparse-checkout set /tools/pytorch-quantization/ && \
git apply --reject --whitespace=fix pytorch_nvidia_quantization.patch && \
cd tools/pytorch-quantization/ && \
python setup.py install

RUN git clone https://github.com/NVIDIA-AI-IOT/torch2trt.git /sw/TensorRT/ && \
cd /sw/TensorRT/ && \
git fetch origin pull/514/head:PR514 && \
git checkout PR514 && \
python setup.py install --plugins
RUN git clone https://github.com/NVIDIA-AI-IOT/torch2trt.git /sw/torch2trt/ && \
cd /sw/torch2trt/scripts && \
bash build_contrib.sh

```

2 changes: 1 addition & 1 deletion setup.py
@@ -41,7 +41,7 @@ def trt_lib_dir():

setup(
name='torch2trt',
version='0.2.0',
version='0.3.0',
description='An easy to use PyTorch to TensorRT converter',
packages=find_packages(exclude=exclude_dir),
ext_package='torch2trt',
48 changes: 48 additions & 0 deletions torch2trt/converters/AdaptiveAvgPool3d.py
@@ -0,0 +1,48 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter(
"torch.nn.AdaptiveAvgPool3d.forward", enabled=trt_version() >= "7.0"
)
def convert_AdaptiveAvgPool3d(ctx):
module = ctx.method_args[0]
input = ctx.method_args[1]
output = ctx.method_return

input_trt = add_missing_trt_tensors(ctx.network, [input])[0]

output_size = module.output_size
if not isinstance(output_size, tuple):
output_size = (output_size,) * 3

stride = (
input_trt.shape[-3] // output_size[-3],
input_trt.shape[-2] // output_size[-2],
input_trt.shape[-1] // output_size[-1],
)

kernel_size = stride
layer = ctx.network.add_pooling_nd(
input=input_trt,
type=trt.PoolingType.AVERAGE,
window_size=kernel_size,
)
layer.stride_nd = stride

output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 16, 224, 224)])
def test_AdaptiveAvgPool3d_1x1x1():
return torch.nn.AdaptiveAvgPool3d((1, 1, 1))


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 16, 224, 224)])
def test_AdaptiveAvgPool3d_2x2x2():
return torch.nn.AdaptiveAvgPool3d((2, 2, 2))


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 16, 224, 224)])
def test_AdaptiveAvgPool3d_3x3x3():
return torch.nn.AdaptiveAvgPool3d((3, 3, 3))
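
Note: the converter emulates adaptive pooling with a fixed window (`kernel_size = stride = input_dim // output_size`). This reproduces PyTorch's adaptive pooling exactly only when each spatial dimension is divisible by the requested output size; otherwise PyTorch's variable-size windows and the fixed window can produce slightly different values. A standalone sketch of the equivalence in the divisible case:

```python
# Sketch: fixed-window average pooling (kernel = stride = in // out) matches
# adaptive pooling when the input size divides evenly by the output size.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 16, 224, 224)

adaptive = F.adaptive_avg_pool3d(x, (2, 2, 2))
fixed = F.avg_pool3d(x, kernel_size=(8, 112, 112), stride=(8, 112, 112))
print(torch.allclose(adaptive, fixed, atol=1e-6))  # True
```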
23 changes: 23 additions & 0 deletions torch2trt/converters/BatchNorm3d.py
@@ -0,0 +1,23 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter("torch.nn.BatchNorm3d.forward", enabled=trt_version() < "7.0")
def convert_BatchNorm3d(ctx):
module = ctx.method_args[0]
input = ctx.method_args[1]
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
output = ctx.method_return

scale = module.weight.detach().cpu().numpy() / np.sqrt(
module.running_var.detach().cpu().numpy() + module.eps
)
bias = (
module.bias.detach().cpu().numpy()
- module.running_mean.detach().cpu().numpy() * scale
)
power = np.ones_like(scale)

layer = ctx.network.add_scale(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power)

output._trt = layer.get_output(0)
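
Note: the converter folds batch norm into a per-channel affine transform, `y = scale * x + bias`, with `scale = weight / sqrt(running_var + eps)` and `bias = bias - running_mean * scale` (`power` is all ones, so the scale layer's exponent is a no-op). A standalone sketch checking that algebra against PyTorch in eval mode:

```python
# Sketch: verify the scale/bias folding used by the converter.
import torch

bn = torch.nn.BatchNorm3d(4).eval()
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1, 1)

    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    bias = bn.bias - bn.running_mean * scale

    x = torch.randn(1, 4, 2, 8, 8)
    folded = x * scale.view(1, -1, 1, 1, 1) + bias.view(1, -1, 1, 1, 1)
    print(torch.allclose(bn(x), folded, atol=1e-5))  # True
```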
6 changes: 6 additions & 0 deletions torch2trt/converters/__init__.py
@@ -7,6 +7,7 @@
from .AdaptiveAvgPool2d import *
from .BatchNorm1d import *
from .BatchNorm2d import *
from .clone import *
from .conv_functional import *
from .Conv import *
from .Conv1d import *
@@ -17,7 +18,9 @@
from .LogSoftmax import *
from .activation import *
from .adaptive_avg_pool2d import *
from .adaptive_avg_pool3d import *
from .adaptive_max_pool2d import *
from .adaptive_max_pool3d import *
from .add import *
from .avg_pool import *
from .batch_norm import *
@@ -35,8 +38,11 @@
from .instance_norm import *
from .interpolate import *
from .layer_norm import *
from .matmul import *
from .max import *
from .max_pool1d import *
from .max_pool2d import *
from .max_pool3d import *
from .mean import *
from .min import *
from .mod import *
11 changes: 11 additions & 0 deletions torch2trt/converters/adaptive_avg_pool3d.py
@@ -0,0 +1,11 @@
from torch2trt.torch2trt import *
from .AdaptiveAvgPool3d import *


@tensorrt_converter("torch.nn.functional.adaptive_avg_pool3d")
def convert_adaptive_avg_pool3d(ctx):
ctx.method_args = (
torch.nn.AdaptiveAvgPool3d(ctx.method_args[1]),
ctx.method_args[0],
)
convert_AdaptiveAvgPool3d(ctx)
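
Note: the functional converter simply re-packages `ctx.method_args` so the module-level `convert_AdaptiveAvgPool3d` above can be reused; this wrap-and-delegate pattern keeps the functional and module converters from drifting apart.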
41 changes: 41 additions & 0 deletions torch2trt/converters/adaptive_max_pool3d.py
@@ -0,0 +1,41 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter("torch.nn.functional.adaptive_max_pool3d")
def convert_adaptive_max_pool3d(ctx):
input = ctx.method_args[0]
output = ctx.method_return

output_size = ctx.method_args[1]
if isinstance(output_size, int):
output_size = (output_size,) * 3

stride = (
input._trt.shape[-3] // output_size[-3],
input._trt.shape[-2] // output_size[-2],
input._trt.shape[-1] // output_size[-1],
)

kernel_size = stride
layer = ctx.network.add_pooling_nd(
input=input._trt, type=trt.PoolingType.MAX, window_size=kernel_size
)
layer.stride_nd = stride

output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 16, 224, 224)])
def test_adaptive_max_pool3d_1x1x1():
return torch.nn.AdaptiveMaxPool3d((1, 1, 1))


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 16, 224, 224)])
def test_adaptive_max_pool3d_2x2x2():
return torch.nn.AdaptiveMaxPool3d((2, 2, 2))


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 16, 224, 224)])
def test_adaptive_max_pool3d_3x3x3():
return torch.nn.AdaptiveMaxPool3d((3, 3, 3))
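
Note: the same fixed-window approximation described for `AdaptiveAvgPool3d` applies here. A minimal sketch exercising the functional form (assuming a CUDA device and TensorRT; `PoolHead` is a hypothetical wrapper for illustration):

```python
import torch
import torch.nn.functional as F
from torch2trt import torch2trt

class PoolHead(torch.nn.Module):
    def forward(self, x):
        return F.adaptive_max_pool3d(x, (2, 2, 2))

x = torch.randn(1, 3, 16, 224, 224).cuda()
model_trt = torch2trt(PoolHead().cuda().eval(), [x])
print(torch.max(torch.abs(F.adaptive_max_pool3d(x, (2, 2, 2)) - model_trt(x))))
```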
83 changes: 83 additions & 0 deletions torch2trt/converters/clone.py
@@ -0,0 +1,83 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def _set_layer_precision(ctx, layer):
# Supported TRT precisions as given by torch2trt_kwargs.
INT8_MODE = "int8_mode"
FP16_MODE = "fp16_mode"

# Check that args exist as expected in torch2trt_kwargs.
trt_kwargs = ctx.torch2trt_kwargs
assert INT8_MODE in trt_kwargs
assert FP16_MODE in trt_kwargs

is_int8 = trt_kwargs.get(INT8_MODE, False)
is_fp16 = trt_kwargs.get(FP16_MODE, False)

if is_int8:
layer.precision = trt.int8
layer.set_output_type(0, trt.int8)
elif is_fp16:
layer.precision = trt.float16
layer.set_output_type(0, trt.float16)


@tensorrt_converter('torch.clone')
@tensorrt_converter('torch.Tensor.clone')
def convert_clone(ctx):
input = ctx.method_args[0]
input_trt = trt_(ctx.network, input)

# Clone by making identity layer.
layer = ctx.network.add_identity(input_trt)
_set_layer_precision(ctx, layer)

output = ctx.method_return
output._trt = layer.get_output(0)


class Clone(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.clone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)])
def test_clone_basic():
return Clone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], fp16_mode=True)
def test_clone_fp16_mode():
return Clone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], int8_mode=True)
def test_clone_int8_mode():
return Clone()


class TorchClone(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.clone(x)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)])
def test_torch_clone_basic():
return TorchClone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], fp16_mode=True)
def test_torch_clone_fp16_mode():
return TorchClone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], int8_mode=True)
def test_torch_clone_int8_mode():
return TorchClone()
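
Note: one plausible reading of `_set_layer_precision` is that, because the clone is implemented as a no-op identity layer, TensorRT might otherwise run it at default (fp32) precision or optimize it away; pinning the layer precision and output type keeps the cloned tensor consistent with the engine's int8/fp16 build mode. A minimal sketch, reusing the `Clone` module defined above (assumes CUDA and TensorRT):

```python
# Sketch: build an fp16 engine exercising the clone converter.
import torch
from torch2trt import torch2trt

x = torch.randn(1, 64, 64).cuda()
model_trt = torch2trt(Clone().cuda().eval(), [x], fp16_mode=True)
print(torch.max(torch.abs(x.clone() - model_trt(x))))  # small fp16 error
```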
4 changes: 2 additions & 2 deletions torch2trt/converters/dummy_converters.py
@@ -3,7 +3,7 @@

def is_private(method):
method = method.split('.')[-1] # remove prefix
return method[0] == '_' and method[1] is not '_'
return method[0] == '_' and method[1] != '_'

def is_function_type(method):
fntype = eval(method + '.__class__.__name__')
@@ -34,4 +34,4 @@ def warn_method(ctx):
@tensorrt_converter('torch.Tensor.dim', is_real=False)
@tensorrt_converter('torch.Tensor.size', is_real=False)
def dont_warn(ctx):
pass
pass
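
Note on the one-line fix in `is_private` above: `is not` tests object identity, not value equality. Comparing a character to the literal `'_'` with `is not` may appear to work because CPython interns short strings, but the behavior is implementation-dependent, and recent CPython versions flag it with `SyntaxWarning: "is not" with a literal`. The `!=` form compares values, which is what the check intends:

```python
# Standalone illustration of identity vs. equality on characters.
method = "_private_name".split('.')[-1]
print(method[0] == '_' and method[1] != '_')  # True: value comparison
# `method[1] is not '_'` relies on string interning, which is fragile
# and flagged by CPython with a SyntaxWarning.
```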