
✨[Feature] Torch stream handling for graph break #3977

@cehongwang

Description

Overview

The current way of handling CUDA streams is that every time we execute an engine, we run it on a separate stream and perform stream synchronization. If there is a graph break, this synchronization happens in every submodule, which is unnecessary. We are devising a new way to handle the CUDA stream.

We register two operators in torch that guard the stream (an eager-mode sketch of how they bracket an engine call follows this list):

  • torch.ops.tensorrt.enter_compute_stream_guard

    • initialize a new stream if needed
    • wait on the default stream if needed
    • set the current stream to this stream
  • torch.ops.tensorrt.exit_compute_stream_guard

    • sync with the main stream
    • set the stream back to the default torch stream
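
For intuition, here is a minimal eager-mode sketch of what the pair of guard ops does around an engine call (run_with_stream_guard and the engine_fn argument are illustrative, not part of the proposal):

import torch

def run_with_stream_guard(engine_fn, inputs):
    # enter: switch to a dedicated compute stream if we are on the default stream
    switched = torch.ops.tensorrt.enter_compute_stream_guard(inputs)
    outputs = engine_fn(inputs)  # stands in for torch.ops.tensorrt.execute_engine
    # exit: sync with the main stream and restore it, only if we switched
    torch.ops.tensorrt.exit_compute_stream_guard(outputs, switched)
    return outputs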

Detailed Explanation

1. Current stream is default stream

compiled_gm(x)

In this situation, we have to create a separate compute stream so that the TRT computation does not run on the default torch stream.

The graph should then do the following:

compiled_gm(x):
    torch.ops.tensorrt.enter_compute_stream_guard:
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.default_stream())
        torch.cuda.set_stream(stream)
    torch.ops.tensorrt.execute_engine
    torch.ops.tensorrt.exit_compute_stream_guard:
        stream = torch.cuda.current_stream()
        torch.cuda.default_stream().wait_stream(stream)
        torch.cuda.set_stream(torch.cuda.default_stream())
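
For reference, a runnable eager-mode equivalent of case 1 (a sketch; model stands in for the engine call):

import torch

def run_on_side_stream(model, x):
    compute = torch.cuda.Stream()
    compute.wait_stream(torch.cuda.default_stream())  # see inputs written on the default stream
    with torch.cuda.stream(compute):
        out = model(x)
    torch.cuda.default_stream().wait_stream(compute)  # later default-stream work sees the outputs
    return out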

2. Current stream is not the default stream

s1 = torch.cuda.Stream()
with torch.cuda.stream(s1):
    compiled_gm(x)

In this situation, we can run the operation on the current stream, and no further manipulation is needed:

compiled_gm(x):
    torch.ops.tensorrt.enter_compute_stream_guard:
        Do Nothing
    torch.ops.tensorrt.execute_engine
    torch.ops.tensorrt.exit_compute_stream_guard:
        Do Nothing

3. Input tensors come from different streams

In this situation, users are required to synchronize the producer streams against the main stream, or against whatever stream they are going to use to run the Torch-TensorRT module. The resulting stream handling is then the same as case 1 if they use the default stream, and the same as case 2 if they use a dedicated stream.

s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
with torch.cuda.stream(s1):
    j = a(i)
with torch.cuda.stream(s2):
    k = b(i)

sync_stream(s1, s2)  # pseudocode; see the sketch below
...
compiled_gm(j, k)
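
sync_stream above is pseudocode; a minimal sketch, assuming the Torch-TensorRT module will run on the current stream, could be:

import torch

def sync_stream(*streams: torch.cuda.Stream) -> None:
    # make the consuming stream wait on each producer stream, so that
    # j and k are fully written before execute_engine launches
    consumer = torch.cuda.current_stream()
    for s in streams:
        consumer.wait_stream(s)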

Implementation detail

Graph Modification

The original graph looks like:

compiled_gm(x):
    %2: tensor = torch.ops.tensorrt.execute_engine(x)
    return %2 

The graph after inserting the ops looks like:

compiled_gm(x):
    %0: bool = torch.ops.tensorrt.enter_compute_stream_guard()
    %1: tensor = torch.ops.tensorrt.execute_engine(x)
    %2: List[tensor] = torch.ops.tensorrt.exit_compute_stream_guard(%0)
    return %1

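A hypothetical sketch of the FX pass that performs this rewrite (insert_stream_guards is an illustrative name and the target match is simplified; the real pass would live in the Torch-TensorRT lowering code):

import torch
from torch.fx import GraphModule

def insert_stream_guards(gm: GraphModule) -> GraphModule:
    for node in list(gm.graph.nodes):
        # match calls to the engine op (in practice the target is an OpOverload)
        if node.op == "call_function" and "execute_engine" in str(node.target):
            with gm.graph.inserting_before(node):
                enter = gm.graph.call_function(
                    torch.ops.tensorrt.enter_compute_stream_guard, args=()
                )
            with gm.graph.inserting_after(node):
                gm.graph.call_function(
                    torch.ops.tensorrt.exit_compute_stream_guard, args=(enter,)
                )
    gm.graph.lint()
    gm.recompile()
    return gm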

Ops registration

The ops are registered in the C++ runtime by default, but if TORCHTRT_RUNTIME is not available (Python-only build), we switch to Python registration.

The implementation looks like:

# C++
# By default, the ops are registered in the C++ runtime, with the same logic
# as the Python code below.

# Python
from typing import List

import torch
from torch_tensorrt._features import EnabledFeatures

if not EnabledFeatures.TORCHTRT_RUNTIME:
    lib = torch.library.Library("tensorrt", "DEF")
    lib.define("enter_compute_stream_guard(Tensor[] x) -> bool")
    lib.define(
        "exit_compute_stream_guard(Tensor[] x, bool return_to_default) -> Tensor[]"
    )

    def enter_compute_stream_guard(x: List[torch.Tensor]) -> bool:
        # if we are on the default stream, switch to a dedicated compute stream
        stream = torch.cuda.current_stream()
        if stream == torch.cuda.default_stream():
            new_stream = torch.cuda.Stream()
            new_stream.wait_stream(torch.cuda.default_stream())
            torch.cuda.set_stream(new_stream)
            return True
        return False

    def exit_compute_stream_guard(
        x: List[torch.Tensor], return_to_default: bool
    ) -> List[torch.Tensor]:
        # only sync and restore if the enter guard switched streams
        if return_to_default:
            torch.cuda.default_stream().wait_stream(torch.cuda.current_stream())
            torch.cuda.set_stream(torch.cuda.default_stream())
        return x

    lib.impl("enter_compute_stream_guard", enter_compute_stream_guard, "CompositeExplicitAutograd")
    lib.impl("exit_compute_stream_guard", exit_compute_stream_guard, "CompositeExplicitAutograd")

@torch.library.register_fake("tensorrt::enter_compute_stream_guard")
def enter_compute_stream_guard_fake(x: List[torch.Tensor]) -> bool:
    ...

@torch.library.register_fake("tensorrt::exit_compute_stream_guard")
def exit_compute_stream_guard_fake(x: List[torch.Tensor], return_to_default: bool) -> List[torch.Tensor]:
    ...
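
A quick smoke test of the Python registration (hypothetical; assumes a CUDA device and that we start on the default stream):

x = [torch.randn(2, 2, device="cuda")]
switched = torch.ops.tensorrt.enter_compute_stream_guard(x)  # True: we were on the default stream
# ... engine work would run on the dedicated compute stream here ...
y = torch.ops.tensorrt.exit_compute_stream_guard(x, switched)
assert torch.cuda.current_stream() == torch.cuda.default_stream()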
