
✨[Feature] Improved parity between convert_method_to_trt_engine and require_full_compilation=True #1939

Closed
@gs-olive

Description


Context

Currently, in the TorchScript path, require_full_compilation=True applies only to conversion operators, i.e. those performing some sort of mathematical operation or transformation; it does not apply to collection packing and unpacking operators. This can lead to confusion in later usage of convert_method_to_trt_engine, since some models which compile successfully with require_full_compilation=True cannot be converted to fully-TRT engine objects with convert_method_to_trt_engine.
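For concreteness, a minimal sketch of the discrepancy (the module, input shapes, and settings here are illustrative, not taken from a specific failing model):

import torch
import torch_tensorrt

class two_output_model(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x), torch.sqrt(x)

traced = torch.jit.trace(two_output_model().eval(), torch.randn(2, 3))
inputs = [torch_tensorrt.Input((2, 3))]

# Succeeds: prim::TupleConstruct is a collection operator, so it is
# exempt from the require_full_compilation check
trt_module = torch_tensorrt.compile(
    traced, inputs=inputs, require_full_compilation=True
)

# Raises: a standalone TRT engine must return a flat list of Tensors,
# but the lowered graph still ends in prim::TupleConstruct
engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
    traced, "forward", inputs=inputs
)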

Proposed Solution

The above is not fully a problem in and of itself: a model which outputs the schema (Tensor, (Tensor, Tensor)) could never be converted into a TRT engine, since its output is not a flat list of Tensors, and yet if its logic can be accelerated to a large extent by Torch-TensorRT, that could reasonably be called "full compilation". Still, there are certain cases where the parity between convert_method_to_trt_engine and require_full_compilation could be improved, namely the case of flat (Tensor, Tensor, ...) Tuple-type outputs. These cases can be addressed by removing the extraneous prim::TupleConstruct call and returning the IValues directly from the TRT engine.

Note: A recent change to Torch/TorchScript prompted the above problem, as discussed in #1368, by changing the default behavior of traced multi-output functions from TensorList to Tuple.
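The changed default is easy to observe by tracing any multi-output function and printing its graph (a small illustrative sketch):

import torch

def two_outputs(x):
    return x * 2, x + 1

traced = torch.jit.trace(two_outputs, torch.randn(4))
print(traced.graph)
# The printed graph now ends in a Tuple, e.g.:
#   %4 : (Tensor, Tensor) = prim::TupleConstruct(%2, %3)
#   return (%4)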

Example

import torch

class my_model(torch.nn.Module):
    def forward(self, x):
        a = 2 * x
        b = a + 2
        c = torch.relu(b)
        d = torch.sqrt(c)
        # Multiple return values: traced as a single Tuple output by default
        return c, d

For the above module, Torch currently considers the (c, d) output to be a Tuple by default, thereby prompting the insertion of an unnecessary prim::TupleConstruct operator. We would prefer that the following (lowered) graph instead register both IValues directly as outputs, so the entire graph could be converted to a TRT engine via torch_tensorrt.convert_method_to_trt_engine.

Lowered Graph:

graph(%x.1 : Tensor):
  %3 : int = prim::Constant[value=1]()
  %2 : int = prim::Constant[value=2]()
  %a.1 : Tensor = aten::mul(%x.1, %2)
  %b.1 : Tensor = aten::add(%a.1, %2, %3)
  %c.1 : Tensor = aten::relu(%b.1)
  %d.1 : Tensor = aten::sqrt(%c.1)
  %8 : (Tensor, Tensor) = prim::TupleConstruct(%c.1, %d.1)
  return (%8)

Preferred Lowered Graph:

graph(%x.1 : Tensor):
  %3 : int = prim::Constant[value=1]()
  %2 : int = prim::Constant[value=2]()
  %a.1 : Tensor = aten::mul(%x.1, %2)
  %b.1 : Tensor = aten::add(%a.1, %2, %3)
  %c.1 : Tensor = aten::relu(%b.1)
  %d.1 : Tensor = aten::sqrt(%c.1)
  return (%c.1, %d.1)
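For illustration only, here is one possible shape for such a lowering step, sketched against TorchScript's semi-private Python IR bindings; the real change would live in Torch-TensorRT's C++ lowering passes, and flatten_tuple_output is a hypothetical helper name:

import torch

def flatten_tuple_output(graph: torch._C.Graph) -> None:
    # If the graph returns a single Tuple built by prim::TupleConstruct,
    # return the tuple's elements directly instead.
    ret = graph.return_node()
    returned = list(ret.inputs())
    if len(returned) != 1:
        return
    producer = returned[0].node()
    if producer.kind() != "prim::TupleConstruct":
        return
    elements = list(producer.inputs())
    ret.removeInput(0)           # drop the Tuple from the return
    for value in elements:
        ret.addInput(value)      # register each element as a graph output
    if len(returned[0].uses()) == 0:
        producer.destroy()       # the TupleConstruct is now dead code

Note that rewriting the graph this way changes the method's output schema, so a pass like this would need to run during lowering, before conversion, rather than on an already-compiled module.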

Additional Context

Related to #1938
