
🐛 [Bug] Cannot convert simple torchscript containing two torch.nn.Upsample operations #1823

Closed
@gcuendet

Description


Bug Description

Scripting a simple "network" containing two torch.nn.Upsample modules and trying to convert the resulting torchscript does not work.

To Reproduce

Steps to reproduce the behavior:

  1. Generate a torchscript of the Network below, with torch.jit.script.
  2. Try to convert to TensorRT with torch_tensorrt (I tried both in C++ on Linux and in Python)
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # NOTE:
        # * Specifying scale_factor as a single float instead of a 2d tuple doesn't change the behavior
        self.upsample1 = torch.nn.Upsample(
            scale_factor=(2.0, 2.0), mode="bilinear", align_corners=False
        )
        self.upsample2 = torch.nn.Upsample(
            scale_factor=(2.0, 2.0), mode="bilinear", align_corners=False
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # NOTE:
        # * Reusing a single self.upsample module doesn't work either
        # * Computing and returning out doesn't work either
        out1 = self.upsample1(X)
        out2 = self.upsample2(X)
        # out = out1 + out2
        return out1
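For reference, the two steps above can be sketched end to end. The input shape and the commented-out compile call are illustrative assumptions, not taken from the attached script:

```python
# Repro sketch (assumes torch and torch_tensorrt are installed; the input
# shape below is an illustrative choice, not from the attached script).
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.upsample1 = torch.nn.Upsample(
            scale_factor=(2.0, 2.0), mode="bilinear", align_corners=False
        )
        self.upsample2 = torch.nn.Upsample(
            scale_factor=(2.0, 2.0), mode="bilinear", align_corners=False
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        out1 = self.upsample1(X)
        out2 = self.upsample2(X)
        return out1


# Step 1: script the module (this step succeeds; only the conversion fails).
scripted = torch.jit.script(Network())
x = torch.rand(1, 3, 32, 32)
print(scripted(x).shape)  # torch.Size([1, 3, 64, 64])

# Step 2: attempt the conversion -- this is the call that fails/hangs:
# import torch_tensorrt
# trt_mod = torch_tensorrt.compile(
#     scripted, inputs=[torch_tensorrt.Input((1, 3, 32, 32))]
# )
```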

Expected behavior

The conversion succeeds and a new, valid torchscript is obtained.

Environment

I reproduced the bug both with PyTorch 1.11 and Torch-TensorRT 1.1.0, and with PyTorch 1.13.1 and Torch-TensorRT main.

Torch-TensorRT 1.1.0

  • Torch-TensorRT Version (e.g. 1.0.0): 1.1.0
  • PyTorch Version (e.g. 1.0): 1.11
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux, Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip for the python package used to generate the torchscript, source for the C++ dependency linked to Torch-TensorRT
  • Build command you used (if compiling from source): Compiled with CMake
  • Python version: 3.8.16
  • CUDA version: 11.3
  • GPU models and configuration: NVIDIA GeForce RTX 2070 SUPER, SM Capability: 7.5
  • Any other relevant information:

When using Torch-TensorRT 1.1.0, I get the following error:

DEBUG: [Torch-TensorRT] - Registering input/output torch::jit::Value for segmented graphs
terminate called after throwing an instance of 'c10::Error'
  what():  Expected Tensor but got Uninitialized


Exception raised from reportToTensorTypeError at /home/ubuntu/buildAgent/temp/buildTmp/conan_home/.conan/data/libtorch/1.11.0-5/cognex/stable/build/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/source_subfolder/aten/src/ATen/core/ivalue.cpp:908 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7f66313e708b in /mnt/caches/conan/data/libtorch/1.11.0-5/cognex/stable/package/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xce (0x7f66313e2a5e in /mnt/caches/conan/data/libtorch/1.11.0-5/cognex/stable/package/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/lib/libc10.so)
frame #2: c10::IValue::reportToTensorTypeError() const + 0x64 (0x7f6634480064 in /mnt/caches/conan/data/libtorch/1.11.0-5/cognex/stable/package/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/lib/libtorch_cpu.so)
frame #3: torch_tensorrt::core::partitioning::getSegmentsOutputByRunning(torch_tensorrt::core::partitioning::SegmentedBlock&, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >&, torch_tensorrt::core::partitioning::PartitionInfo const&) + 0x15a7 (0x7f6639fd8797 in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #4: torch_tensorrt::core::partitioning::runShapeAnalysis(std::vector<torch_tensorrt::core::partitioning::SegmentedBlock, std::allocator<torch_tensorrt::core::partitioning::SegmentedBlock> >&, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >&, torch_tensorrt::core::partitioning::PartitionInfo const&) + 0x81 (0x7f6639fda581 in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #5: torch_tensorrt::core::partitioning::Partition(torch::jit::Block*, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >&, torch_tensorrt::core::partitioning::PartitionInfo const&) + 0x19b (0x7f6639fe587b in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #6: torch_tensorrt::core::ConstructFallbackGraph(torch::jit::Module&, torch::jit::Block*, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >, torch_tensorrt::core::CompileSpec, std::map<torch::jit::Value*, c10::IValue, std::less<torch::jit::Value*>, std::allocator<std::pair<torch::jit::Value* const, c10::IValue> > >) + 0xff (0x7f663a00006f in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #7: torch_tensorrt::core::CompileGraph(torch::jit::Module const&, torch_tensorrt::core::CompileSpec) + 0x980 (0x7f663a002a00 in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #8: torch_tensorrt::torchscript::compile(torch::jit::Module const&, torch_tensorrt::torchscript::CompileSpec) + 0x5b7 (0x7f6639e9e7f7 in /mnt/caches/conan/data/torch-tensorrt/1.1.0-6/cognex/stable/package/c2bc9df050867ddb44d73bc998d3c8ceef4f763a/lib/libtorchtrt.so)
frame #9: <unknown function> + 0x5ec8 (0x56183fd18ec8 in ./build/src/interpolate_tensorrt)
frame #10: __libc_start_main + 0xf3 (0x7f65fe02f083 in /lib/x86_64-linux-gnu/libc.so.6)
frame #11: <unknown function> + 0x58ee (0x56183fd188ee in ./build/src/interpolate_tensorrt)

That looks similar to this issue, and patching Torch-TensorRT with this PR makes the behavior exactly the same as in the second case (i.e. PyTorch 1.13.1 with Torch-TensorRT main).

Torch-TensorRT main (commit 861edd0)

  • Torch-TensorRT Version (e.g. 1.0.0): main, commit 861edd0
  • PyTorch Version (e.g. 1.0): 1.13.1
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux, Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip for the python package used to generate the torchscript, source for the C++ dependency linked to Torch-TensorRT
  • Build command you used (if compiling from source): Compiled with CMake
  • Python version: 3.8.16
  • CUDA version: 11.8
  • GPU models and configuration: NVIDIA GeForce RTX 2070 SUPER, SM Capability: 7.5
  • Any other relevant information:

When using Torch-TensorRT main, the conversion hangs forever after

GRAPH: [Torch-TensorRT] - Torch-TensorRT.TorchScript Graph Lowering

Additional context

Interestingly, when using PyTorch's tracing mechanism to generate the torchscript, everything seems fine (I didn't check the results, but the conversion finishes properly).
Also, when scripting with PyTorch 1.9, everything works fine 🤯
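The tracing path that reportedly converts fine can be produced like this (a single Upsample module shown for brevity; the example input shape is an illustrative assumption, not from the report):

```python
# Workaround sketch: trace instead of script. torch.jit.trace records the
# concrete ops executed for this example input, so the resulting graph
# avoids the scripted Upsample code path entirely.
import torch

upsample = torch.nn.Upsample(
    scale_factor=(2.0, 2.0), mode="bilinear", align_corners=False
)
example = torch.rand(1, 3, 32, 32)  # illustrative shape, not from the report

traced = torch.jit.trace(upsample, example)
print(traced(example).shape)  # torch.Size([1, 3, 64, 64])
```

Note that a traced graph bakes in the behavior observed for the example input, which is why it can sidestep scripting-only issues like this one.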

One thing I noticed is that PyTorch slightly changed the torch.nn.functional.interpolate API, and I wonder whether that could (at least partially) explain the problem:

  • Before pytorch 1.11: torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)
  • PyTorch 1.11 and more recent: torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False)
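To check which of the two signatures a given installation exposes, one can inspect the function directly (a small diagnostic sketch; it only reports what the locally installed PyTorch provides):

```python
# Inspect the keyword arguments of the installed interpolate implementation.
import inspect

import torch.nn.functional as F

params = list(inspect.signature(F.interpolate).parameters)
print(params)
print("antialias" in params)  # True on PyTorch >= 1.11, False before
```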

See the attached .zip file containing a python file to generate the torchscript.
upsample.zip

Let me know if you need more details to reproduce the problem. Thanks!
