1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
@@ -7,6 +7,7 @@
from .AdaptiveAvgPool2d import *
from .BatchNorm1d import *
from .BatchNorm2d import *
from .clone import *
from .conv_functional import *
from .Conv import *
from .Conv1d import *
83 changes: 83 additions & 0 deletions torch2trt/converters/clone.py
@@ -0,0 +1,83 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def _set_layer_precision(ctx, layer):
    # Supported TRT precisions as given by torch2trt_kwargs.
    INT8_MODE = "int8_mode"
    FP16_MODE = "fp16_mode"

    # Check that args exist as expected in torch2trt_kwargs.
    trt_kwargs = ctx.torch2trt_kwargs
    assert INT8_MODE in trt_kwargs
    assert FP16_MODE in trt_kwargs

    is_int8 = trt_kwargs.get(INT8_MODE, False)
    is_fp16 = trt_kwargs.get(FP16_MODE, False)

    if is_int8:
        layer.precision = trt.int8
        layer.set_output_type(0, trt.int8)
    elif is_fp16:
        layer.precision = trt.float16
        layer.set_output_type(0, trt.float16)


@tensorrt_converter('torch.clone')
@tensorrt_converter('torch.Tensor.clone')
def convert_clone(ctx):
    input = ctx.method_args[0]
    input_trt = trt_(ctx.network, input)

    # Clone by adding an identity layer.
    layer = ctx.network.add_identity(input_trt)
    _set_layer_precision(ctx, layer)

    output = ctx.method_return
    output._trt = layer.get_output(0)


class Clone(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.clone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)])
def test_clone_basic():
    return Clone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], fp16_mode=True)
def test_clone_fp16_mode():
    return Clone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], int8_mode=True)
def test_clone_int8_mode():
    return Clone()


class TorchClone(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.clone(x)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)])
def test_torch_clone_basic():
    return TorchClone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], fp16_mode=True)
def test_torch_clone_fp16_mode():
    return TorchClone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], int8_mode=True)
def test_torch_clone_int8_mode():
    return TorchClone()
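
For reference, a minimal sketch (not part of this diff) of how the new converter would be exercised end to end through torch2trt. The Clone module and input shape mirror the tests above; fp16_mode=True is chosen so that _set_layer_precision takes its float16 branch. It assumes a CUDA device and a working TensorRT installation.

import torch
from torch2trt import torch2trt

# Hypothetical smoke test for the clone converter (illustration only).
class Clone(torch.nn.Module):
    def forward(self, x):
        return x.clone()

model = Clone().eval().cuda()
x = torch.randn(1, 64, 64).cuda()

# fp16_mode=True exercises the float16 branch of _set_layer_precision.
model_trt = torch2trt(model, [x], fp16_mode=True)

# clone maps to a TensorRT identity layer, so the outputs should closely match.
print(torch.max(torch.abs(model(x) - model_trt(x))))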