Description
Overview
The current way of handling the CUDA stream is to execute the engine on a separate stream and perform stream synchronization on every run. If there is a graph break, this synchronization happens in every submodule, which is unnecessary. We are devising a new way to handle the CUDA stream.
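For illustration, the current behavior looks roughly like the following (an illustrative sketch only; run_engine_today and engine.run are stand-ins, not the actual runtime code):

import torch

def run_engine_today(inputs, engine):
    caller = torch.cuda.current_stream()
    compute = torch.cuda.Stream()        # a fresh compute stream on every call
    compute.wait_stream(caller)          # wait for the inputs to be ready
    with torch.cuda.stream(compute):
        outputs = engine.run(inputs)
    caller.wait_stream(compute)          # synchronize back on every run
    return outputs

With a graph break, every TRT submodule pays this synchronization, which is what the new guards avoid.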
We register two operators with torch that guard the stream:
- torch.ops.tensorrt.enter_compute_stream_guard
  - initialize a new stream if needed
  - wait on the default stream if needed
  - set the current stream to this stream
- torch.ops.tensorrt.exit_compute_stream_guard
  - sync with the main stream
  - set the stream back to the default torch stream
Detailed Explanation
1. Current stream is the default stream
compiled_gm(x)
In this situation, we have to create a separate compute stream so that the TRT compute stream does not run on the default torch stream. The graph should do the following:
compiled_gm(x):
    torch.ops.tensorrt.enter_compute_stream_guard:
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.default_stream())
        torch.cuda.set_stream(stream)
    torch.ops.tensorrt.execute_engine
    torch.ops.tensorrt.exit_compute_stream_guard:
        stream = torch.cuda.current_stream()
        torch.cuda.default_stream().wait_stream(stream)
        torch.cuda.set_stream(torch.cuda.default_stream())
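From the caller's point of view, the results are then safe to consume on the default stream right away, because the exit guard makes the default stream wait on the compute stream. A minimal usage sketch (assuming compiled_gm is a Torch-TensorRT compiled module and a hypothetical CUDA input):

import torch

x = torch.randn(8, 3, 224, 224, device="cuda")
y = compiled_gm(x)      # enter guard moves the work onto a dedicated compute stream
# The exit guard made the default stream wait on the compute stream,
# so later default-stream work can use y without an explicit synchronize.
print(y.float().mean())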
2. Current stream is not the default stream:
s1 = torch.cuda.Stream()
with torch.cuda.stream(s1):
    compiled_gm(x)
In this situation, we can run on the current stream and no other manipulation is needed:
compiled_gm(x):
    torch.ops.tensorrt.enter_compute_stream_guard:
        do nothing
    torch.ops.tensorrt.execute_engine
    torch.ops.tensorrt.exit_compute_stream_guard:
        do nothing
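In this case the caller owns synchronization of its own stream. A minimal caller-side sketch (hypothetical usage, assuming compiled_gm is the compiled module):

import torch

x = torch.randn(8, 3, 224, 224, device="cuda")
s1 = torch.cuda.Stream()
s1.wait_stream(torch.cuda.current_stream())   # x was produced on the default stream
with torch.cuda.stream(s1):
    y = compiled_gm(x)   # guards do nothing; the engine runs on s1
# The caller synchronizes s1 as usual, e.g. before consuming y on the default stream:
torch.cuda.current_stream().wait_stream(s1)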
3. Input tensors come from different streams
In this situation, users are required to synchronize those streams against the main stream, or against whichever stream they are going to run the Torch-TensorRT module on. The resulting stream handling is the same as case 1 if users use the default stream and case 2 if they use a dedicated stream.
s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
with torch.cuda.stream(s1):
    j = a(i)
with torch.cuda.stream(s2):
    k = b(i)
sync_stream(s1, s2)
...
compiled_gm(j, k)
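sync_stream(s1, s2) above is a placeholder. One way to express the required synchronization with the existing torch.cuda API is to make the stream that will launch the Torch-TensorRT module wait on both producer streams (a sketch, here for case 1, i.e. launching on the default stream):

launch_stream = torch.cuda.current_stream()   # the stream compiled_gm will run on
launch_stream.wait_stream(s1)                 # j is now safe to read on launch_stream
launch_stream.wait_stream(s2)                 # k is now safe to read on launch_stream
out = compiled_gm(j, k)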
Implementation detail
Graph Modification
The original graph looks like:
compiled_gm(x):
    %2: tensor = torch.ops.tensorrt.execute_engine(x)
    return %2
The graph after inserting the ops looks like:
compiled_gm(x):
    %0: bool = torch.ops.tensorrt.enter_compute_stream_guard()
    %1: tensor = torch.ops.tensorrt.execute_engine(x)
    %2: List[tensor] = torch.ops.tensorrt.exit_compute_stream_guard(%0)
    return %1
graph(List[tensor]) -> List[tensor]
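A minimal sketch of how such a pass could insert the guard nodes into an FX GraphModule (insert_stream_guards is a hypothetical name; the target matching and argument wiring are assumptions that mirror the before/after graphs above, and the guard ops must already be registered):

import torch
import torch.fx

def insert_stream_guards(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and "tensorrt.execute_engine" in str(node.target):
            with gm.graph.inserting_before(node):
                enter = gm.graph.call_function(
                    torch.ops.tensorrt.enter_compute_stream_guard, args=()
                )
            with gm.graph.inserting_after(node):
                gm.graph.call_function(
                    torch.ops.tensorrt.exit_compute_stream_guard, args=(enter,)
                )
    gm.graph.lint()
    gm.recompile()
    return gm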
Ops registration
The ops are registered in the C++ runtime by default. If TORCHTRT_RUNTIME is not available (a Python-only build), we switch to Python registration.
The implementation looks like:
# C++
# Similar to the Python code below, in C++ syntax.

# Python
from typing import List

import torch
from torch_tensorrt._features import EnabledFeatures

if not EnabledFeatures.TORCHTRT_RUNTIME:
    # Register the op schemas, then attach Python implementations.
    torch.library.define("tensorrt::enter_compute_stream_guard", "(Tensor[] x) -> bool")
    torch.library.define("tensorrt::exit_compute_stream_guard", "(Tensor[] x, bool return_to_default) -> Tensor[]")

    @torch.library.impl("tensorrt::enter_compute_stream_guard", "cuda")
    def enter_compute_stream_guard(x: List[torch.Tensor]) -> bool:
        stream = torch.cuda.current_stream()
        if stream == torch.cuda.default_stream():
            # Move the TRT work onto a dedicated compute stream.
            new_stream = torch.cuda.Stream()
            new_stream.wait_stream(torch.cuda.default_stream())
            torch.cuda.set_stream(new_stream)
            return True
        return False

    ...

    @torch.library.impl("tensorrt::exit_compute_stream_guard", "cuda")
    def exit_compute_stream_guard(x: List[torch.Tensor], return_to_default: bool) -> List[torch.Tensor]:
        if return_to_default:
            # Make the default stream wait on the compute stream, then restore it.
            torch.cuda.default_stream().wait_stream(torch.cuda.current_stream())
            torch.cuda.set_stream(torch.cuda.default_stream())
        return x

    @torch.library.register_fake("tensorrt::enter_compute_stream_guard")
    def enter_compute_stream_guard_fake(x: List[torch.Tensor]) -> bool:
        ...

    @torch.library.register_fake("tensorrt::exit_compute_stream_guard")
    def exit_compute_stream_guard_fake(x: List[torch.Tensor], return_to_default: bool) -> List[torch.Tensor]:
        ...
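Once registered (from either runtime), the ops can be exercised directly. A hypothetical sanity check, starting from the default stream:

import torch

x = [torch.randn(4, device="cuda")]
moved = torch.ops.tensorrt.enter_compute_stream_guard(x)   # True: moved to a compute stream
# ... execute_engine would run here, on the current (compute) stream ...
outs = torch.ops.tensorrt.exit_compute_stream_guard(x, moved)
assert torch.cuda.current_stream() == torch.cuda.default_stream()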