Description
Problem Context
In certain TorchScript control-flow nodes, one block uses the prim::RaiseException primitive to enforce an invariant, while the other block performs the computation. One example is nn.Upsample, whose graph is shown below.
Upsample Graph Snippet
%out1.1 : Tensor = prim::If(%45)
  block0():
    %51 : Tensor = aten::upsample_bilinear2d(%X.1, %18, %119, %115)
    -> (%51)
  block1():
    %53 : bool = aten::eq(%36, %24)
     = prim::If(%53)
      block0():
         = prim::RaiseException(%12, %11)
        -> ()
      block1():
        -> ()
    %56 : bool = aten::eq(%36, %25)
     = prim::If(%56)
      block0():
         = prim::RaiseException(%10, %11)
        -> ()
      block1():
        -> ()
    %59 : str = aten::format(%9, %36, %27)
     = prim::RaiseException(%59, %11)
    -> (%30)
In the graph above, the outermost block0 performs the aten::upsample_bilinear2d computation, while the outermost block1 exists only to raise: every path through it ends in a prim::RaiseException call (fed by the aten::eq and aten::format nodes) that informs the user of dimensionality issues and similar problems. While helpful, our converter implementation of aten::upsample_bilinear2d should handle dimension issues and report them to the user itself, instead of depending on the nn.Module code to do so. As such, we can remove the prim::RaiseException calls here.
Note further that the graph above contains several dangling prim::If nodes, which produce no outputs and are never assigned to any variable. These appear difficult to remove, since calling node->destroy() on them seems to segfault.
Desired Solution
The desired solution is a lowering pass which detects whether a control-flow node has a guaranteed exception along one of its paths and, if so, eliminates the control flow entirely, replacing the prim::If with the nodes contained in the valid path. We already use torch::jit::EliminateExceptions; however, this pass only replaces the control-flow boolean and not its logic, and it seems to halt computation indefinitely in certain cases (see #1823). We also use the following lowering pass:
The above is a good starting point for a solution to this problem, but it does not fully solve the issue since it only tracks very specific instances of control flow logic containing exceptions.
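To make the desired behavior concrete, here is a minimal sketch of the pass on a toy IR written in plain Python. The Node and Block classes, always_raises, and eliminate_exception_branches are hypothetical names invented for illustration; they are not the torch::jit API and not the Torch-TensorRT implementation. The sketch also ignores value rewiring (for example, replacing uses of %out1.1 with %51), which a real pass would have to handle.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    # Toy stand-in for a graph node; only "prim::If" nodes carry blocks.
    kind: str
    blocks: List["Block"] = field(default_factory=list)


@dataclass
class Block:
    # Toy stand-in for a graph block: an ordered list of nodes.
    nodes: List[Node] = field(default_factory=list)


def always_raises(block: Block) -> bool:
    """Return True if every path through `block` reaches prim::RaiseException."""
    for node in block.nodes:
        if node.kind == "prim::RaiseException":
            return True
        if node.kind == "prim::If" and all(always_raises(b) for b in node.blocks):
            return True
    return False


def eliminate_exception_branches(block: Block) -> None:
    """Inline the valid branch of each two-way prim::If whose other branch always raises."""
    rewritten: List[Node] = []
    for node in block.nodes:
        if node.kind == "prim::If" and len(node.blocks) == 2:
            # Simplify nested control flow first.
            for sub in node.blocks:
                eliminate_exception_branches(sub)
            raising = [always_raises(b) for b in node.blocks]
            if raising.count(True) == 1:
                # Keep only the nodes of the branch that does not raise.
                valid = node.blocks[raising.index(False)]
                rewritten.extend(valid.nodes)
                continue
        rewritten.append(node)
    block.nodes = rewritten


def raise_guard() -> Node:
    # Models the dangling `= prim::If(...)` nodes from the Upsample snippet.
    return Node("prim::If", [Block([Node("prim::RaiseException")]), Block()])


# Toy version of the Upsample graph snippet.
graph = Block([
    Node("prim::If", [
        Block([Node("aten::upsample_bilinear2d")]),
        Block([
            Node("aten::eq"), raise_guard(),
            Node("aten::eq"), raise_guard(),
            Node("aten::format"), Node("prim::RaiseException"),
        ]),
    ])
])

eliminate_exception_branches(graph)
print([n.kind for n in graph.nodes])  # ['aten::upsample_bilinear2d']
```

In this toy model the dangling guard nodes disappear as a side effect: their valid branch is empty, so inlining it removes the prim::If without a separate destroy step. Whether the same approach avoids the node->destroy() segfault in the real graph is not something this sketch can confirm.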
Note
The lowering pass described above could be considered an "unsafe" lowering pass, in the sense that it removes exceptions intended to catch anomalous cases. Torch-TensorRT currently has evaluator support for prim::RaiseException operators. The pass could potentially be enabled via a compile-time flag, such as eliminate_exceptions=True, which would improve performance by removing exception-handling control flow.
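One way such opt-in gating might look is sketched below in plain Python. CompileSettings, build_lowering_passes, drop_raise_exception, and lower are hypothetical names, and eliminate_exceptions is only the flag name suggested in this issue, not an existing Torch-TensorRT option. The graph is reduced to a flat list of op names to keep the example short.

```python
from dataclasses import dataclass
from typing import Callable, List

Pass = Callable[[List[str]], List[str]]


@dataclass
class CompileSettings:
    # Hypothetical flag; defaults to False so exceptions are kept unless
    # the user explicitly opts in to the unsafe pass.
    eliminate_exceptions: bool = False


def drop_raise_exception(ops: List[str]) -> List[str]:
    """Stand-in for the unsafe pass: drops prim::RaiseException ops."""
    return [op for op in ops if op != "prim::RaiseException"]


def build_lowering_passes(settings: CompileSettings) -> List[Pass]:
    passes: List[Pass] = []  # safe passes would be appended unconditionally
    if settings.eliminate_exceptions:
        passes.append(drop_raise_exception)
    return passes


def lower(ops: List[str], settings: CompileSettings) -> List[str]:
    """Run the configured passes in order over a flat list of op names."""
    for lowering_pass in build_lowering_passes(settings):
        ops = lowering_pass(ops)
    return ops
```

Defaulting the flag to off keeps the existing evaluator path for prim::RaiseException as the standard behavior, so users only lose the safety checks when they ask for it.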
Additional Context
For additional context, see #1823.