11 changes: 0 additions & 11 deletions torch2trt/converters/Identity.py

This file was deleted.

11 changes: 0 additions & 11 deletions torch2trt/converters/ReLU.py

This file was deleted.

23 changes: 0 additions & 23 deletions torch2trt/converters/ReLU6.py

This file was deleted.

3 changes: 0 additions & 3 deletions torch2trt/converters/__init__.py
@@ -12,11 +12,8 @@
 from .Conv2d import *
 from .ConvTranspose import *
 from .ConvTranspose2d import *
-from .Identity import *
 from .Linear import *
 from .LogSoftmax import *
-from .ReLU import *
-from .ReLU6 import *
 from .activation import *
 from .adaptive_avg_pool2d import *
 from .adaptive_max_pool2d import *
12 changes: 11 additions & 1 deletion torch2trt/converters/identity.py
@@ -5,8 +5,18 @@
 @tensorrt_converter('torch.nn.functional.dropout')
 @tensorrt_converter('torch.nn.functional.dropout2d')
 @tensorrt_converter('torch.nn.functional.dropout3d')
-def convert_identity(ctx):
+def convert_functional_identity(ctx):
     input = ctx.method_args[0]
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
     output = ctx.method_return
     output._trt = input_trt
+
+
+@tensorrt_converter('torch.nn.Dropout.forward')
+@tensorrt_converter('torch.nn.Dropout2d.forward')
+@tensorrt_converter('torch.nn.Dropout3d.forward')
+def convert_identity(ctx):
+    input = ctx.method_args[1]
+    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
+    output = ctx.method_return
+    output._trt = input_trt
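
A note on the two converters above: for the functional entry points (torch.nn.functional.dropout*), ctx.method_args[0] is the input tensor, while for the bound-method entry points (torch.nn.Dropout*.forward), method_args[0] is the module instance itself, so the input tensor sits at index 1. A minimal sketch of the same no-op pattern against a hypothetical target that this PR does not touch (torch.nn.AlphaDropout.forward; imports name the helpers this file pulls in via its star import):

    from torch2trt.torch2trt import tensorrt_converter, add_missing_trt_tensors

    @tensorrt_converter('torch.nn.AlphaDropout.forward')  # hypothetical target, for illustration only
    def convert_alpha_dropout(ctx):
        # method_args[0] is the AlphaDropout module (self); the tensor is at [1]
        input = ctx.method_args[1]
        input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
        output = ctx.method_return
        # dropout-style layers are identity at inference time, so alias the
        # input's TensorRT tensor onto the output
        output._trt = input_trt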
16 changes: 13 additions & 3 deletions torch2trt/converters/relu.py
@@ -1,11 +1,21 @@
 from torch2trt.torch2trt import *
-from .ReLU import *
 
 
 @tensorrt_converter('torch.relu')
 @tensorrt_converter('torch.relu_')
 @tensorrt_converter('torch.nn.functional.relu')
 @tensorrt_converter('torch.nn.functional.relu_')
-def convert_relu(ctx):
+def convert_functional_relu(ctx):
     ctx.method_args = (torch.nn.ReLU(),) + ctx.method_args
-    convert_ReLU(ctx)
+    convert_relu(ctx)
+
+
+@tensorrt_converter('torch.nn.ReLU.forward')
+def convert_relu(ctx):
+    input = ctx.method_args[1]
+    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
+    output = ctx.method_return
+    layer = ctx.network.add_activation(
+        input=input_trt, type=trt.ActivationType.RELU)
+    output._trt = layer.get_output(0)
+

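The functional converter reuses the module converter by prepending a fresh torch.nn.ReLU() to ctx.method_args, so the input tensor lands at index 1 either way. A quick end-to-end check, as a sketch (assumes a CUDA device with TensorRT and this branch installed; torch2trt(model, [inputs]) is the library's standard entry point):

    import torch
    from torch2trt import torch2trt

    # Module-style (nn.ReLU) and functional-style (torch.relu) calls now both
    # map to a single TensorRT RELU activation layer.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).cuda().eval()
    x = torch.randn(1, 3, 32, 32).cuda()
    model_trt = torch2trt(model, [x])
    print(torch.max(torch.abs(model(x) - model_trt(x))))  # expect ~0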
28 changes: 25 additions & 3 deletions torch2trt/converters/relu6.py
@@ -1,8 +1,30 @@
 from torch2trt.torch2trt import *
-from .ReLU6 import *
+from torch2trt.module_test import add_module_test
 
 
 @tensorrt_converter('torch.nn.functional.relu6')
-def convert_relu6(ctx):
+def convert_functional_relu6(ctx):
     ctx.method_args = (torch.nn.ReLU6(),) + ctx.method_args
-    convert_ReLU6(ctx)
+    convert_relu6(ctx)
+
+
+@tensorrt_converter('torch.nn.ReLU6.forward')
+def convert_relu6(ctx):
+    input = ctx.method_args[1]
+    output = ctx.method_return
+
+    input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input, 6])
+    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
+
+    layer = ctx.network.add_activation(
+        input=input_a_trt, type=trt.ActivationType.RELU)
+    layer = ctx.network.add_elementwise(
+        layer.get_output(0), input_b_trt, trt.ElementWiseOperation.MIN)
+
+    output._trt = layer.get_output(0)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
+def test_relu6_basic():
+    return torch.nn.ReLU6()
+
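
This converter lowers ReLU6(x) = min(max(x, 0), 6) as a TensorRT RELU activation followed by an elementwise MIN against a broadcast constant 6. The identity behind the lowering can be sanity-checked in pure PyTorch, no TensorRT required:

    import torch

    # min(relu(x), 6) should match relu6(x) exactly across the whole range.
    x = torch.linspace(-3.0, 9.0, steps=13)
    reference = torch.nn.functional.relu6(x)
    lowered = torch.minimum(torch.relu(x), torch.tensor(6.0))
    assert torch.equal(reference, lowered)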