
✨[Feature] Add Lowering Pass to Eliminate If/Else Blocks with Exceptions in TorchScript #1842

Closed
gs-olive opened this issue Apr 20, 2023 · 2 comments
Labels
feature request New feature or request


@gs-olive
Collaborator

gs-olive commented Apr 20, 2023

Problem Context

In some TorchScript control-flow constructs, one branch of a prim::If enforces an invariant by calling prim::RaiseException, while the other branch performs the actual computation. One example is nn.Upsample, whose graph is shown below.

Upsample Graph Snippet
  %out1.1 : Tensor = prim::If(%45)
    block0():
      %51 : Tensor = aten::upsample_bilinear2d(%X.1, %18, %119, %115)
      -> (%51)
    block1():
      %53 : bool = aten::eq(%36, %24)
       = prim::If(%53)
        block0():
           = prim::RaiseException(%12, %11)
          -> ()
        block1():
          -> ()
      %56 : bool = aten::eq(%36, %25)
       = prim::If(%56)
        block0():
           = prim::RaiseException(%10, %11)
          -> ()
        block1():
          -> ()
      %59 : str = aten::format(%9, %36, %27)
       = prim::RaiseException(%59, %11)
      -> (%30)

In the graph above, the outermost block0 performs the aten::upsample_bilinear2d computation, while the outermost block1 exists only to report errors: its nested prim::If nodes raise on specific dimensionality issues, and it ends with an unconditional prim::RaiseException, so that branch always throws. While these checks are helpful, our converter for aten::upsample_bilinear2d should itself detect dimension issues and report them to the user, instead of depending on the nn.Module code to do so. As such, the prim::RaiseException calls here can be removed.
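The key observation is that the outer block1 is guaranteed to raise: its nested prim::If nodes raise only conditionally, but the final top-level prim::RaiseException always runs. A minimal sketch of that check follows, written against a toy IR rather than the real torch::jit classes (the Node/Block structs and the MakeNode, MakeIf, MakeBlock, and AlwaysRaises names are hypothetical). It treats a block as certainly raising if a top-level node raises, or if a top-level prim::If raises on both of its branches.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for TorchScript IR (hypothetical, not the torch::jit API):
// a block is an ordered list of nodes, and a prim::If node owns two blocks.
struct Block;

struct Node {
  std::string kind;                            // e.g. "prim::If"
  std::vector<std::shared_ptr<Block>> blocks;  // exactly two for prim::If
};

struct Block {
  std::vector<std::shared_ptr<Node>> nodes;
};

std::shared_ptr<Node> MakeNode(const std::string& kind) {
  return std::make_shared<Node>(Node{kind, {}});
}

std::shared_ptr<Node> MakeIf(std::shared_ptr<Block> then_block,
                             std::shared_ptr<Block> else_block) {
  return std::make_shared<Node>(Node{"prim::If", {then_block, else_block}});
}

std::shared_ptr<Block> MakeBlock(std::vector<std::shared_ptr<Node>> nodes) {
  return std::make_shared<Block>(Block{std::move(nodes)});
}

// True if executing `block` is guaranteed to raise. Top-level nodes run in
// order, so any top-level node that certainly raises decides the result.
bool AlwaysRaises(const Block& block) {
  for (const auto& node : block.nodes) {
    if (node->kind == "prim::RaiseException") return true;
    if (node->kind == "prim::If") {
      assert(node->blocks.size() == 2);
      // A conditional raise only counts if both branches raise.
      if (AlwaysRaises(*node->blocks[0]) && AlwaysRaises(*node->blocks[1]))
        return true;
    }
  }
  return false;
}
```

Anything the check does not recognize, such as a prim::Loop whose body raises (the loop may run zero times), is conservatively treated as not raising, so the pass would leave such control flow untouched.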

Note further that the graph above contains several dangling prim::If nodes that produce no outputs (the lines of the form " = prim::If(...)"). These seem difficult to remove, since calling node->destroy() on them appears to segfault.
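The issue does not say why node->destroy() crashes here. A common cause in graph passes is destroying the node that a loop iterator points at and then advancing that iterator, which reads freed memory, for example by calling n->destroy() inside a range-based for over block->nodes(). Assuming that is the cause (unconfirmed), the fix is to step past the node before destroying it, or to use the TorchScript node-list iterator's destroyCurrent() method. The sketch below shows the same pattern on a hypothetical std::list-based block (ToyNode, IsExceptionGuard, and RemoveExceptionGuards are illustrative names, not torch::jit API), removing output-less prim::If nodes that only guard an invariant.

```cpp
#include <cassert>
#include <list>
#include <string>
#include <vector>

// Hypothetical toy node (not torch::jit::Node). `branches` holds the node
// kinds inside each block of a prim::If.
struct ToyNode {
  std::string kind;
  int num_outputs;
  std::vector<std::vector<std::string>> branches;
};

// An output-less prim::If whose branches are empty or only raise exists
// purely to enforce an invariant.
bool IsExceptionGuard(const ToyNode& n) {
  if (n.kind != "prim::If") return false;
  assert(n.branches.size() == 2);  // prim::If always owns two blocks
  if (n.num_outputs != 0) return false;
  for (const auto& branch : n.branches)
    for (const auto& kind : branch)
      if (kind != "prim::RaiseException") return false;
  return true;
}

// Erases guards in a single traversal. std::list::erase returns the iterator
// following the erased element, so the loop never advances an iterator that
// points at a destroyed node.
int RemoveExceptionGuards(std::list<ToyNode>& block) {
  int removed = 0;
  for (auto it = block.begin(); it != block.end();) {
    if (IsExceptionGuard(*it)) {
      it = block.erase(it);
      ++removed;
    } else {
      ++it;
    }
  }
  return removed;
}
```

The aten::eq nodes that fed the removed guards become unused, so a dead-code elimination pass would normally clean them up afterwards.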

Desired Solution

The desired solution is a lowering pass that detects whether one path of a control-flow block is guaranteed to raise an exception and, if so, eliminates the control flow entirely, replacing the prim::If with the nodes contained in the valid path. We already use torch::jit::EliminateExceptions; however, this pass only replaces the prim::If condition with a constant boolean, leaving the branching logic itself in place, and it seems to hang indefinitely in certain cases (see #1823). We also use the following lowering pass:

void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph);

This pass is a good starting point, but it does not fully solve the problem, since it only matches a few very specific patterns of control flow containing exceptions.
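A more general version of the pass described under Desired Solution could visit every prim::If, recursing into nested blocks first, and inline whichever branch does not raise. Below is a minimal sketch on a hypothetical toy IR (string-named values, the Node/Block structs, and the Rename and InlineNonRaisingBranches names are all illustrative, not torch::jit API), assuming each prim::If has exactly two blocks whose returned values map one-to-one onto the If's outputs.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy IR (hypothetical, not torch::jit): values are named by strings like "%51".
struct Block;

struct Node {
  std::string kind;
  std::vector<std::string> inputs, outputs;
  std::vector<std::shared_ptr<Block>> blocks;  // exactly two for prim::If
};

struct Block {
  std::vector<std::shared_ptr<Node>> nodes;
  std::vector<std::string> returns;  // the values in "-> (...)"
};

// Simplified check: a top-level prim::RaiseException means the block raises.
bool AlwaysRaises(const Block& b) {
  for (const auto& n : b.nodes)
    if (n->kind == "prim::RaiseException") return true;
  return false;
}

// Rewrites every use of a renamed value in `b` and its nested blocks.
void Rename(Block& b, const std::map<std::string, std::string>& rename) {
  auto fix = [&rename](std::string& v) {
    auto it = rename.find(v);
    if (it != rename.end()) v = it->second;
  };
  for (auto& n : b.nodes) {
    for (auto& in : n->inputs) fix(in);
    for (auto& sub : n->blocks) Rename(*sub, rename);
  }
  for (auto& r : b.returns) fix(r);
}

// Replaces each prim::If that has exactly one always-raising branch with the
// nodes of the other branch, and redirects uses of the If's outputs to that
// branch's returned values. Inner blocks are processed first.
void InlineNonRaisingBranches(Block& b) {
  std::vector<std::shared_ptr<Node>> kept_nodes;
  std::map<std::string, std::string> rename;
  auto resolve = [&rename](const std::string& v) {
    auto it = rename.find(v);
    return it == rename.end() ? v : it->second;
  };
  for (auto& n : b.nodes) {
    for (auto& sub : n->blocks) InlineNonRaisingBranches(*sub);
    if (n->kind == "prim::If") {
      assert(n->blocks.size() == 2);
      bool raises0 = AlwaysRaises(*n->blocks[0]);
      bool raises1 = AlwaysRaises(*n->blocks[1]);
      if (raises0 != raises1) {
        Block& survivor = raises0 ? *n->blocks[1] : *n->blocks[0];
        assert(survivor.returns.size() == n->outputs.size());
        for (auto& m : survivor.nodes) kept_nodes.push_back(m);
        for (std::size_t i = 0; i < n->outputs.size(); ++i)
          rename[n->outputs[i]] = resolve(survivor.returns[i]);
        continue;  // the If node itself is dropped
      }
    }
    kept_nodes.push_back(n);
  }
  b.nodes = std::move(kept_nodes);
  Rename(b, rename);
}
```

When both branches raise, the If is left in place, since neither path is valid. The now-unused condition (%45 in the Upsample snippet) would be left for dead-code elimination.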

Note

The lowering pass described above could be considered "unsafe", in the sense that it removes exceptions intended to catch anomalous cases. Torch-TensorRT currently has evaluator support for prim::RaiseException operators, so the pass could be made opt-in, enabled via a compile-time flag such as eliminate_exceptions=True, which would improve performance by removing those exceptions.
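A minimal sketch of how such an opt-in flag might gate the pass, assuming a C++-side options struct; LoweringOptions, BuildLoweringPipeline, and the pass name strings are hypothetical, and the Python-facing eliminate_exceptions=True argument is assumed to map onto the boolean field.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical options struct; the names are illustrative only and are not
// Torch-TensorRT's actual settings.
struct LoweringOptions {
  bool eliminate_exceptions = false;  // unsafe, so disabled unless requested
};

// Returns the ordered list of lowering passes to run. Only the unsafe
// exception-eliminating pass depends on the flag; the others always run.
std::vector<std::string> BuildLoweringPipeline(const LoweringOptions& opts) {
  std::vector<std::string> passes = {"ConstantPropagation"};
  if (opts.eliminate_exceptions) {
    // Hypothetical name for the pass proposed in this issue.
    passes.push_back("EliminateExceptionBranches");
  }
  // Clean up values (e.g. If conditions) orphaned by the optional pass.
  passes.push_back("EliminateDeadCode");
  return passes;
}
```

Keeping the default false preserves the current behavior, where prim::RaiseException is handled by the evaluator, unless the user explicitly trades those checks for a simpler graph.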

Additional Context

For additional context, see #1823.

@bowang007
Collaborator

Looks like this is related to #1357 as well.

gcuendet pushed a commit to gcuendet/Torch-TensorRT that referenced this issue May 8, 2023
Following [this issue](pytorch#1823),
and [this proposal](pytorch#1842),
this commit implements the proposal in a way that is extremely specific
to Upsample resulting in upsample_bilinear2d.
gcuendet pushed a commit to gcuendet/Torch-TensorRT that referenced this issue May 15, 2023
Following [this issue](pytorch#1823), and [this proposal](pytorch#1842),
this commit implements the proposal in a way that is extremely specific
to Upsample resulting in upsample_bilinear2d.
@github-actions

This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days


4 participants