from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def _set_layer_precision(ctx, layer):
    # Supported TRT precisions as given by torch2trt_kwargs.
    INT8_MODE = "int8_mode"
    FP16_MODE = "fp16_mode"

    # Check that args exist as expected in torch2trt_kwargs.
    trt_kwargs = ctx.torch2trt_kwargs
    assert INT8_MODE in trt_kwargs
    assert FP16_MODE in trt_kwargs

    is_int8 = trt_kwargs.get(INT8_MODE, False)
    is_fp16 = trt_kwargs.get(FP16_MODE, False)

    if is_int8:
        layer.precision = trt.int8
        layer.set_output_type(0, trt.int8)
    elif is_fp16:
        layer.precision = trt.float16
        layer.set_output_type(0, trt.float16)

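
# Note (assumption, not stated in this file): TensorRT treats per-layer precision
# and output types as hints unless the engine is built with strict type
# constraints. If the identity layer must actually run in int8/fp16, the module
# would be converted with something like the sketch below; the
# strict_type_constraints kwarg is assumed here and may differ across
# torch2trt versions.
#
#   model_trt = torch2trt(model, [x], fp16_mode=True, strict_type_constraints=True)
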
@tensorrt_converter('torch.clone')
@tensorrt_converter('torch.Tensor.clone')
def convert_clone(ctx):
    input = ctx.method_args[0]
    input_trt = trt_(ctx.network, input)

    # Clone by adding an identity layer.
    layer = ctx.network.add_identity(input_trt)
    _set_layer_precision(ctx, layer)

    output = ctx.method_return
    output._trt = layer.get_output(0)

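
# Minimal usage sketch (an assumption, not part of this file): exercising the
# converter above through the standard torch2trt entry point. The Clone module
# referenced here is defined below.
#
#   import torch
#   from torch2trt import torch2trt
#
#   model = Clone().cuda().eval()
#   x = torch.randn(1, 64, 64).cuda()
#   model_trt = torch2trt(model, [x], fp16_mode=True)
#   print(torch.max(torch.abs(model(x) - model_trt(x))))
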
class Clone(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.clone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)])
def test_clone_basic():
    return Clone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], fp16_mode=True)
def test_clone_fp16_mode():
    return Clone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], int8_mode=True)
def test_clone_int8_mode():
    return Clone()


class TorchClone(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.clone(x)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)])
def test_torch_clone_basic():
    return TorchClone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], fp16_mode=True)
def test_torch_clone_fp16_mode():
    return TorchClone()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 64, 64)], int8_mode=True)
def test_torch_clone_int8_mode():
    return TorchClone()
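
# The add_module_test decorators above register these modules with torch2trt's
# module test harness. Assuming the standard test runner (flag names may differ
# across versions), the clone tests can be run with something like:
#
#   python3 -m torch2trt.test --name=clone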