
✨[Feature] Add Lowering Pass to Eliminate If/Else Blocks with Exceptions in TorchScript #1842

Closed
@gs-olive

Description


Problem Context

In some TorchScript control-flow constructs, one block of a prim::If uses the prim::RaiseException primitive to enforce an invariant, while the other block performs the actual computation. nn.Upsample is one example; its graph is shown below.

Upsample Graph Snippet
  %out1.1 : Tensor = prim::If(%45)
    block0():
      %51 : Tensor = aten::upsample_bilinear2d(%X.1, %18, %119, %115)
      -> (%51)
    block1():
      %53 : bool = aten::eq(%36, %24)
       = prim::If(%53)
        block0():
           = prim::RaiseException(%12, %11)
          -> ()
        block1():
          -> ()
      %56 : bool = aten::eq(%36, %25)
       = prim::If(%56)
        block0():
           = prim::RaiseException(%10, %11)
          -> ()
        block1():
          -> ()
      %59 : str = aten::format(%9, %36, %27)
       = prim::RaiseException(%59, %11)
      -> (%30)

In the graph above, the outermost block0 performs the aten::upsample_bilinear2d computation, while every path through the outermost block1 ends in a prim::RaiseException call that informs the user of dimensionality issues and similar problems. While these checks are helpful, our converter for aten::upsample_bilinear2d should detect such dimension issues and report them to the user itself, rather than depending on the nn.Module code to do so. As such, the prim::RaiseException calls can be removed here.

Note also that the graph contains several dangling prim::If nodes (such as prim::If(%53) and prim::If(%56)) whose results are never assigned to a variable. These are difficult to remove, since calling node->destroy() on them appears to cause a segfault.
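As written, those dangling nodes still contain raises. Once their raising branch is gone, however, they have no outputs and only empty blocks, so they are dead code. The following is a minimal sketch of removing such nodes on a toy IR: the `Node`/`Block` types, `isDeadIf`, and `removeDeadIfs` are hypothetical stand-ins, not torch::jit. It rebuilds the node list with erase-remove instead of deleting nodes mid-traversal. In the real API, passes that delete nodes while walking a block usually go through the node-list iterator's `destroyCurrent()`; whether iterator invalidation is the cause of the segfault here is an assumption.

```cpp
#include <algorithm>
#include <memory>
#include <string>
#include <vector>

// Hypothetical, simplified IR used only for illustration (not torch::jit).
struct Block;
struct Node {
  std::string kind;                            // e.g. "prim::If", "aten::eq"
  std::vector<std::string> inputs;             // names of consumed values
  std::vector<std::string> outputs;            // names of produced values
  std::vector<std::shared_ptr<Block>> blocks;  // sub-blocks (two for prim::If)
};
struct Block {
  std::vector<std::shared_ptr<Node>> nodes;
  std::vector<std::string> returns;  // values yielded by "-> (...)"
};

// True for a prim::If that produces nothing and whose blocks are all empty.
bool isDeadIf(const Node& n) {
  if (n.kind != "prim::If" || !n.outputs.empty()) return false;
  return std::all_of(n.blocks.begin(), n.blocks.end(),
                     [](const std::shared_ptr<Block>& b) {
                       return b->nodes.empty() && b->returns.empty();
                     });
}

// Bottom-up: clean nested blocks first (an If may only become dead once its
// inner Ifs are gone), then drop dead prim::If nodes with erase-remove so no
// iterator is used after the element it points to has been removed.
void removeDeadIfs(Block& block) {
  for (auto& n : block.nodes)
    for (auto& sub : n->blocks) removeDeadIfs(*sub);
  block.nodes.erase(
      std::remove_if(block.nodes.begin(), block.nodes.end(),
                     [](const std::shared_ptr<Node>& n) { return isDeadIf(*n); }),
      block.nodes.end());
}
```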

Desired Solution

The desired solution is a lowering pass that detects whether one branch of a prim::If is guaranteed to raise an exception and, if so, eliminates the control flow entirely, replacing the prim::If with the nodes of the non-raising branch. We already use torch::jit::EliminateExceptions; however, this pass only replaces the control-flow condition, not the control flow itself, and in certain cases it appears to hang indefinitely (see #1823). We also use the following lowering pass:

void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {

This pass is a good starting point, but it does not fully solve the problem, since it only matches a few specific patterns of control flow containing exceptions.
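The more general pass described under Desired Solution can be sketched on a toy IR. The `Node`/`Block` types, `alwaysRaises`, `applyRenames`, and `eliminateExceptionBranches` are hypothetical illustrations, not Torch-TensorRT or torch::jit APIs. The sketch handles only two-way prim::If, and it treats a block as "guaranteed to raise" only if straight-line execution reaches a prim::RaiseException, either directly or through a prim::If whose branches all raise. Like the proposal, it is unsafe: it also drops guards such as prim::If(%53), whose non-raising branch is empty.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical, simplified IR used only for illustration (not torch::jit).
struct Block;
struct Node {
  std::string kind;                            // e.g. "prim::If"
  std::vector<std::string> inputs;             // names of consumed values
  std::vector<std::string> outputs;            // names of produced values
  std::vector<std::shared_ptr<Block>> blocks;  // sub-blocks (two for prim::If)
};
struct Block {
  std::vector<std::shared_ptr<Node>> nodes;
  std::vector<std::string> returns;  // values yielded by "-> (...)"
};

// A block always raises if straight-line execution reaches a
// prim::RaiseException, or a prim::If whose branches all raise.
// Loops and other control flow are ignored in this sketch.
bool alwaysRaises(const Block& b) {
  for (const auto& n : b.nodes) {
    if (n->kind == "prim::RaiseException") return true;
    if (n->kind == "prim::If" && !n->blocks.empty()) {
      bool all = true;
      for (const auto& sub : n->blocks) all = all && alwaysRaises(*sub);
      if (all) return true;
    }
  }
  return false;
}

// Rewrites every use of a renamed value in b, including nested blocks.
void applyRenames(Block& b, const std::map<std::string, std::string>& m) {
  auto fix = [&m](std::string& v) {
    auto it = m.find(v);
    if (it != m.end()) v = it->second;
  };
  for (auto& n : b.nodes) {
    for (auto& in : n->inputs) fix(in);
    for (auto& sub : n->blocks) applyRenames(*sub, m);
  }
  for (auto& r : b.returns) fix(r);
}

// Replaces each two-way prim::If in which exactly one branch always raises
// with the nodes of the other branch, redirecting uses of the If's outputs
// to that branch's returned values. Nested blocks are processed first.
void eliminateExceptionBranches(Block& b) {
  std::vector<std::shared_ptr<Node>> rebuilt;
  std::map<std::string, std::string> renames;
  for (auto& n : b.nodes) {
    for (auto& sub : n->blocks) eliminateExceptionBranches(*sub);
    if (n->kind == "prim::If" && n->blocks.size() == 2) {
      bool raises0 = alwaysRaises(*n->blocks[0]);
      bool raises1 = alwaysRaises(*n->blocks[1]);
      if (raises0 != raises1) {
        Block& keep = raises0 ? *n->blocks[1] : *n->blocks[0];
        assert(keep.returns.size() == n->outputs.size());
        rebuilt.insert(rebuilt.end(), keep.nodes.begin(), keep.nodes.end());
        for (std::size_t i = 0; i < n->outputs.size(); ++i) {
          std::string target = keep.returns[i];
          auto it = renames.find(target);  // resolve chained renames
          if (it != renames.end()) target = it->second;
          renames[n->outputs[i]] = target;
        }
        continue;  // the prim::If itself is dropped
      }
    }
    rebuilt.push_back(n);
  }
  b.nodes = std::move(rebuilt);
  applyRenames(b, renames);
}
```

Applied to a model of the Upsample snippet, this leaves only the aten::upsample_bilinear2d node, and later uses of %out1.1 are redirected to %51. Condition computations outside the removed blocks that become unused (for example, an aten::eq feeding a removed top-level guard) would still need a dead-code pass such as torch::jit::EliminateDeadCode.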

Note

The lowering pass described above could be considered "unsafe", in the sense that it removes exceptions intended to catch anomalous inputs. Torch-TensorRT currently has evaluator support for prim::RaiseException operators, so exceptions can remain in the graph by default. The pass could then be enabled via an opt-in compile-time flag, such as eliminate_exceptions=True, which would improve performance by removing the exception checks.
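As a sketch of the opt-in behaviour, assume a hypothetical `LoweringSettings` struct and a hypothetical `buildPassPipeline` function that returns pass names (Torch-TensorRT's actual settings type, field names, and pass registration may differ). The unsafe pass, given the hypothetical name `EliminateExceptionBranches`, is appended only when the user sets the flag. Otherwise, exceptions stay in the graph and are handled by the existing prim::RaiseException evaluator.

```cpp
#include <string>
#include <vector>

// Hypothetical lowering options; only eliminate_exceptions relates to this issue.
struct LoweringSettings {
  bool eliminate_exceptions = false;  // opt-in: off unless the user asks for it
};

// Returns the names of the passes to run, in order (names are illustrative).
std::vector<std::string> buildPassPipeline(const LoweringSettings& settings) {
  std::vector<std::string> passes = {"EliminateExceptionOrPassPattern"};
  if (settings.eliminate_exceptions) {
    // "Unsafe": drops checks that guard against anomalous inputs.
    passes.push_back("EliminateExceptionBranches");
  }
  return passes;
}
```

Keeping the default at `false` means existing users see no behavioural change, and only callers who accept the loss of input validation get the simplified graph.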

Additional Context

For additional context, see #1823.
