fix: Record cudagraphs when weight streaming budget has changed #3309

Merged · 2 commits · Dec 20, 2024
4 changes: 4 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -101,6 +101,7 @@ TRTEngine::TRTEngine(

   runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
   runtime_states.old_pre_allocated_outputs = false;
+  runtime_states.context_changed = false;

   if (_in_binding_names.size() == 0 && _out_binding_names.size() == 0) {
     uint64_t inputs = 0;
@@ -310,6 +311,9 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
   if (profile_execution) {
     enable_profiling();
   }
+  // Indicates to reevaluate the runtime settings
+  runtime_states.context_changed = true;
+
   return result;
 }

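The caller-visible effect, sketched below (a minimal illustration, not part of the diff; `rt_mod` and `example_inputs` are assumed handles to a weight-streaming-enabled Torch-TensorRT runtime module and its GPU inputs):

```python
# Minimal sketch: changing the weight streaming budget recreates the TensorRT
# execution context, so the new flag forces the next forward pass under CUDA
# graphs to reset and re-record instead of replaying a stale capture.
rt_mod.set_device_memory_budget(1 << 30)  # sets runtime_states.context_changed
out = rt_mod(*example_inputs)             # reset + re-record, then replay
```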
20 changes: 16 additions & 4 deletions core/runtime/TRTEngine.h
@@ -35,25 +35,37 @@ struct TorchTRTRuntimeStates {
   bool old_cudagraphs;
   // Indicates whether pre-allocated output was enabled in the previous execute_engine
   bool old_pre_allocated_outputs;
+  // Indicates whether context has changed
+  bool context_changed;

-  // Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
+  // Evaluates whether certain conditions are met to enable CUDA Graph recording/reset or to reuse pre-allocated outputs
   // based on the current and previous states, as well as input shape has changed
-  std::tuple<bool, bool> set_runtime_states(bool new_cudagraphs, bool new_pre_allocated_output, bool shape_changed) {
+  std::tuple<bool, bool, bool> set_runtime_states(
+      bool new_cudagraphs,
+      bool new_pre_allocated_output,
+      bool shape_changed) {
     bool need_cudagraphs_record = false;
     bool can_use_pre_allocated_outputs = false;
+    bool need_cudagraphs_reset = false;

     // Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
-    if (new_cudagraphs && (!old_cudagraphs || shape_changed)) {
+    if (new_cudagraphs && (!old_cudagraphs || shape_changed || context_changed)) {
       need_cudagraphs_record = true;
     }
     // Pre-allocated output can be used when previous and current state are true without shape change
     if (old_pre_allocated_outputs && new_pre_allocated_output && !shape_changed) {
       can_use_pre_allocated_outputs = true;
     }
+    if (!new_cudagraphs || shape_changed || context_changed) {
+      need_cudagraphs_reset = true;
+    }

     old_cudagraphs = new_cudagraphs;
     old_pre_allocated_outputs = new_pre_allocated_output;
+    // Reset flag
+    context_changed = false;

-    return {need_cudagraphs_record, can_use_pre_allocated_outputs};
+    return {need_cudagraphs_record, can_use_pre_allocated_outputs, need_cudagraphs_reset};
   }
 };

3 changes: 2 additions & 1 deletion core/runtime/execute_engine.cpp
@@ -211,8 +211,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

   bool need_cudagraphs_record = std::get<0>(result);
   bool can_use_pre_allocated_outputs = std::get<1>(result);
+  bool need_cudagraphs_reset = std::get<2>(result);

-  if (!cudagraphs_enabled || shape_changed) {
+  if (need_cudagraphs_reset) {
     compiled_engine->cudagraph.reset();
   }

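The reset decision moves out of execute_engine and into the shared state object. A small equivalence sketch (Python, restating the flag semantics above; not shipped code):

```python
# Old inline condition in execute_engine:
#   reset if not cudagraphs_enabled or shape_changed
# New condition computed by set_runtime_states, which also folds in context changes
# (e.g., a weight streaming budget update that recreated the execution context):
def need_cudagraphs_reset(new_cudagraphs: bool, shape_changed: bool, context_changed: bool) -> bool:
    return not new_cudagraphs or shape_changed or context_changed
```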
py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
@@ -27,12 +27,12 @@ def __init__(
         super(CudaGraphsTorchTensorRTModule, self).__init__()
         self.compiled_module = compiled_module
         self.inputs = partitioning.construct_submodule_inputs(compiled_module)
+        self.is_weight_streaming_set = False

         self._input_buffers: List[torch.Tensor] = []
         self._output_buffers: List[torch.Tensor] = []
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
         self.shape_key: Optional[str] = None
-        self.prev_cudagraphs_enabled = False
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
         self.warm_up()
@@ -77,15 +77,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
-            # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
-            need_cudagraphs_record = not self.prev_cudagraphs_enabled or shape_changed
-            self.prev_cudagraphs_enabled = cudagraphs_enabled
-
+            need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
                 self._input_buffers = [None] * len(self.inputs)

+            self.is_weight_streaming_set = False
             # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
             contiguous_inputs: List[torch.Tensor] = [
                 (
57 changes: 37 additions & 20 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -24,25 +24,33 @@


 class TorchTRTRuntimeStates:
-    def __init__(self, new_cudagraphs: bool, new_pre_allocated_output: bool):
+    def __init__(self, new_cudagraphs: bool):
         # Indicates whether CUDAGraphs were enabled in the previous execute_engine
         self.old_cudagraphs = new_cudagraphs
         # Indicates whether pre-allocated output was enabled in the previous execute_engine
-        self.old_pre_allocated_outputs = new_pre_allocated_output
+        self.old_pre_allocated_outputs = False
+        # Indicates whether context has changed
+        self.context_changed = False

     def set_runtime_states(
         self,
         new_cudagraphs: bool,
         new_pre_allocated_output: bool,
         shape_changed: bool,
-    ) -> Tuple[bool, bool]:
-        # Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
+    ) -> Tuple[bool, bool, bool]:
+        # Evaluates whether certain conditions are met to enable CUDA Graph recording or to use pre-allocated outputs
         # based on the current and previous states, as well as input shape has changed
         need_cudagraphs_record = False
         can_use_pre_allocated_outputs = False
+        need_cudagraphs_reset = False

-        # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
-        if new_cudagraphs and (not self.old_cudagraphs or shape_changed):
+        # CUDA Graph recording is needed if CUDA graphs is enabled and:
+        # - CUDA graphs were previously disabled
+        # - or the shape has changed
+        # - or the execution context has changed (e.g., weight streaming)
+        if new_cudagraphs and (
+            not self.old_cudagraphs or shape_changed or self.context_changed
+        ):
             need_cudagraphs_record = True

         # Pre-allocated output can be used when previous and current state are true without shape change
@@ -53,10 +61,19 @@ def set_runtime_states(
         ):
             can_use_pre_allocated_outputs = True

+        if not new_cudagraphs or shape_changed or self.context_changed:
+            need_cudagraphs_reset = True
+
         self.old_cudagraphs = new_cudagraphs
         self.old_pre_allocated_outputs = new_pre_allocated_output
+        # reset flag
+        self.context_changed = False

-        return need_cudagraphs_record, can_use_pre_allocated_outputs
+        return (
+            need_cudagraphs_record,
+            can_use_pre_allocated_outputs,
+            need_cudagraphs_reset,
+        )
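A short walkthrough of the state machine above (a hedged sketch; assumes `TorchTRTRuntimeStates` from this file is in scope). The returned tuple is `(need_cudagraphs_record, can_use_pre_allocated_outputs, need_cudagraphs_reset)`:

```python
states = TorchTRTRuntimeStates(new_cudagraphs=True)

# Steady state: cudagraphs stay on, same shape -> nothing to do.
print(states.set_runtime_states(True, False, False))  # (False, False, False)

# A weight streaming budget change marks the context as changed,
# so the next call requests both a reset and a fresh record.
states.context_changed = True
print(states.set_runtime_states(True, False, False))  # (True, False, True)

# The flag is consumed inside set_runtime_states, so the call after is steady again.
print(states.set_runtime_states(True, False, False))  # (False, False, False)

# Toggling cudagraphs off requests a reset without a record.
print(states.set_runtime_states(False, False, False))  # (False, False, True)
```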


class PythonTorchTensorRTModule(Module): # type: ignore[misc]
@@ -145,7 +162,7 @@ def __init__(
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
         self.runtime_states = TorchTRTRuntimeStates(
-            torch_tensorrt.runtime.get_cudagraphs_mode(), False
+            torch_tensorrt.runtime.get_cudagraphs_mode()
         )
         self.pre_allocated_outputs: List[torch.Tensor] = []
         self.use_pre_allocated_outputs = False
@@ -168,6 +185,7 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
         del self.context
         budget_bytes = self._set_device_memory_budget(budget_bytes)
         self.context = self.engine.create_execution_context()
+        self.runtime_states.context_changed = True
         return budget_bytes

def _set_device_memory_budget(self, budget_bytes: int) -> int:
@@ -200,7 +218,6 @@ def setup_engine(self) -> None:
         if self.settings.enable_weight_streaming:
             self.set_default_device_memory_budget()
         self.context = self.engine.create_execution_context()
-
         assert self.engine.num_io_tensors == (
             len(self.input_names) + len(self.output_names)
         )
@@ -356,22 +373,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .

         cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
         shape_changed = self.validate_input_shapes(inputs)
-        need_cudagraphs_record, can_use_pre_allocated_outputs = (
-            self.runtime_states.set_runtime_states(
-                cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
-            )
+        (
+            need_cudagraphs_record,
+            can_use_pre_allocated_outputs,
+            need_cudagraphs_reset,
+        ) = self.runtime_states.set_runtime_states(
+            cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
         )

+        if need_cudagraphs_reset and self.cudagraph:
+            self.cudagraph.reset()
+            self.cudagraph = None
+
         if need_cudagraphs_record:
-            if self.cudagraph:
-                self.cudagraph.reset()
             self._input_buffers = [None] * len(self.input_names)
             self._output_buffers = [None] * len(self.output_names)

-        if not cudagraphs_enabled and self.cudagraph:
-            self.cudagraph.reset()
-            self.cudagraph = None
-
         # If in safe mode, check at each iteration for whether a switch is required
         if (
             torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
4 changes: 4 additions & 0 deletions py/torch_tensorrt/runtime/_weight_streaming.py
@@ -20,8 +20,10 @@ def __init__(
     ) -> None:
         rt_mods = []
         self.current_device_budget = 0
+        self.cuda_graphs_module = None

         if isinstance(module, CudaGraphsTorchTensorRTModule):
+            self.cuda_graphs_module = module
             module = module.compiled_module
         for name, rt_mod in module.named_children():
             if "_run_on_acc" in name and isinstance(
@@ -78,6 +80,8 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
             ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i])
             logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}")

+        if self.cuda_graphs_module:
+            self.cuda_graphs_module.is_weight_streaming_set = True
         return ws_budget_bytes

def __setattr__(self, name: str, value: Any) -> None:
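End to end, the fix means a budget change under CUDA graphs now triggers a re-record instead of replaying a stale capture. A hypothetical usage sketch (`MyModel`, the input shape, and the budget fractions are assumptions; the context managers are the public `torch_tensorrt.runtime` APIs this PR touches):

```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # assumed user model
inputs = [torch.randn(8, 128).cuda()]
# Weight streaming requires building the engine with explicit typing.
trt_model = torch_tensorrt.compile(
    model,
    inputs=inputs,
    use_explicit_typing=True,
    enable_weight_streaming=True,
)

with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as cudagraphs_module:
    with torch_tensorrt.runtime.weight_streaming(cudagraphs_module) as ws_ctx:
        ws_ctx.device_budget = ws_ctx.total_device_budget // 2
        out = cudagraphs_module(*inputs)  # records a CUDA graph under this budget
        ws_ctx.device_budget = ws_ctx.total_device_budget // 4
        out = cudagraphs_module(*inputs)  # is_weight_streaming_set -> re-record
```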