
Commit d2ebdaf

Merge pull request #633 from chaoz-dev/chaoz-dev/converters-clone
Add torch.clone to converters.
2 parents 15b7f89 + 572d422

2 files changed (+84, -0)


torch2trt/converters/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from .AdaptiveAvgPool2d import *
 from .BatchNorm1d import *
 from .BatchNorm2d import *
+from .clone import *
 from .conv_functional import *
 from .Conv import *
 from .Conv1d import *
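
The wildcard import is what activates the converter: importing torch2trt.converters executes clone.py (shown below), and its @tensorrt_converter decorators register convert_clone for both call forms. User code can register additional converters the same way without editing this file; a hypothetical sketch following the same pattern (convert_contiguous and the contiguous-to-identity mapping are illustrative, not part of this PR):

from torch2trt.torch2trt import tensorrt_converter, trt_

# Hypothetical user-side converter, registered as a side effect of import.
@tensorrt_converter('torch.Tensor.contiguous')
def convert_contiguous(ctx):
    input = ctx.method_args[0]
    input_trt = trt_(ctx.network, input)
    # contiguous() is a memory-layout no-op for TensorRT, so identity suffices.
    layer = ctx.network.add_identity(input_trt)
    ctx.method_return._trt = layer.get_output(0)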

torch2trt/converters/clone.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+from torch2trt.torch2trt import *
+from torch2trt.module_test import add_module_test
+
+
+def _set_layer_precision(ctx, layer):
+    # Supported TRT precisions as given by torch2trt_kwargs.
+    INT8_MODE = "int8_mode"
+    FP16_MODE = "fp16_mode"
+
+    # Check that args exist as expected in torch2trt_kwargs.
+    trt_kwargs = ctx.torch2trt_kwargs
+    assert INT8_MODE in trt_kwargs
+    assert FP16_MODE in trt_kwargs
+
+    is_int8 = trt_kwargs.get(INT8_MODE, False)
+    is_fp16 = trt_kwargs.get(FP16_MODE, False)
+
+    if is_int8:
+        layer.precision = trt.int8
+        layer.set_output_type(0, trt.int8)
+    elif is_fp16:
+        layer.precision = trt.float16
+        layer.set_output_type(0, trt.float16)
+
+
+@tensorrt_converter('torch.clone')
+@tensorrt_converter('torch.Tensor.clone')
+def convert_clone(ctx):
+    input = ctx.method_args[0]
+    input_trt = trt_(ctx.network, input)
+
+    # Clone by making an identity layer.
+    layer = ctx.network.add_identity(input_trt)
+    _set_layer_precision(ctx, layer)
+
+    output = ctx.method_return
+    output._trt = layer.get_output(0)
+
+
+class Clone(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.clone()
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)])
+def test_clone_basic():
+    return Clone()
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], fp16_mode=True)
+def test_clone_fp16_mode():
+    return Clone()
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], int8_mode=True)
+def test_clone_int8_mode():
+    return Clone()
+
+
+class TorchClone(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.clone(x)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)])
+def test_torch_clone_basic():
+    return TorchClone()
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], fp16_mode=True)
+def test_torch_clone_fp16_mode():
+    return TorchClone()
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], int8_mode=True)
+def test_torch_clone_int8_mode():
+    return TorchClone()
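
With the converter registered, models that call clone() inside forward() convert instead of failing with a missing-converter warning. A minimal end-to-end sketch (assumes a CUDA device and a working TensorRT install; Clone is the test module from clone.py above):

import torch
from torch2trt import torch2trt

model = Clone().cuda().eval()      # Clone test module defined above
x = torch.randn(1, 64, 64).cuda()

# Keyword arguments such as fp16_mode land in ctx.torch2trt_kwargs,
# which _set_layer_precision reads to pin the identity layer's precision.
model_trt = torch2trt(model, [x])                        # fp32 engine
model_trt_fp16 = torch2trt(model, [x], fp16_mode=True)   # fp16 engine

# clone() maps to an identity layer, so outputs should match closely.
print(torch.max(torch.abs(model(x) - model_trt(x))))

The add_module_test entries above are picked up by torch2trt's module test runner; an invocation along the lines of python3 -m torch2trt.test --name=clone should run just these tests, though the exact flags depend on the checkout.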

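One subtlety in _set_layer_precision: layer.precision on its own is only a hint to the TensorRT builder. Pinning the output type with set_output_type(0, ...) as well makes it far more likely that the identity layer actually executes, and produces its output, at the reduced precision rather than being absorbed into neighboring fp32 work. For reference, a standalone TensorRT sketch of builder flags under which such per-layer pins are honored (plain TensorRT Python API, not torch2trt internals; STRICT_TYPES exists in TRT 7/8 and is deprecated in favor of the PREFER/OBEY_PRECISION_CONSTRAINTS flags in newer releases):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)          # permit fp16 kernels
config.set_flag(trt.BuilderFlag.STRICT_TYPES)  # honor per-layer precision pins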