Skip to content

Commit 2050887

Browse files
committed
chore: rebased from opt_out_buffer branch.
1 parent 7c5123a commit 2050887

File tree

5 files changed

+125
-80
lines changed

5 files changed

+125
-80
lines changed

core/runtime/TRTEngine.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,8 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
310310
if (profile_execution) {
311311
enable_profiling();
312312
}
313-
// Indicates to reevaluate the runtime settings
314-
has_context_changed = true;
313+
314+
runtime_states.set_context_changed();
315315

316316
return result;
317317
}

core/runtime/TRTEngine.h

+20-4
Original file line numberDiff line numberDiff line change
@@ -32,33 +32,49 @@ using FlattenedState = std::tuple<
3232

3333
struct RuntimeStates {
3434
bool need_cudagraphs_record;
35+
bool need_cudagraphs_reset;
3536
bool can_use_pre_allocated_outputs;
3637
};
3738

3839
struct TorchTRTRuntimeStates {
3940
// Previous runtime states
40-
bool prev_cudagraphs_enabled, prev_pre_allocated_outputs_enabled;
41+
bool prev_cudagraphs_enabled = false;
42+
bool prev_pre_allocated_outputs_enabled = false;
43+
// Indicates to reevaluate the runtime settings as context has changed
44+
bool has_context_changed = false;
4145

42-
// Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
46+
// Evaluates whether certain conditions are met to enable CUDA Graph recording/reset or to reuse pre-allocated outputs
4347
// based on the current and previous states, as well as whether the input shape has changed
4448
RuntimeStates validate_states(bool cudagraphs_enabled, bool pre_allocated_outputs_enabled, bool shape_changed) {
4549
bool need_cudagraphs_record = false;
4650
bool can_use_pre_allocated_outputs = false;
51+
bool need_cudagraphs_reset = false;
4752

4853
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
49-
if (cudagraphs_enabled && (!prev_cudagraphs_enabled || shape_changed)) {
54+
// If the context was changed by a runtime setting such as weight streaming, the CUDA graph must be re-recorded
55+
if (cudagraphs_enabled && (!prev_cudagraphs_enabled || shape_changed || has_context_changed)) {
5056
need_cudagraphs_record = true;
5157
}
5258
// Pre-allocated output can be used when previous and current state are true without shape change
5359
if (prev_pre_allocated_outputs_enabled && pre_allocated_outputs_enabled && !shape_changed) {
5460
can_use_pre_allocated_outputs = true;
5561
}
62+
63+
if (!cudagraphs_enabled || shape_changed || has_context_changed) {
64+
need_cudagraphs_reset = true;
65+
}
66+
67+
// Reset the flag
68+
has_context_changed = false;
5669
prev_cudagraphs_enabled = cudagraphs_enabled;
5770
prev_pre_allocated_outputs_enabled = pre_allocated_outputs_enabled;
5871

59-
RuntimeStates values = {need_cudagraphs_record, can_use_pre_allocated_outputs};
72+
RuntimeStates values = {need_cudagraphs_record, need_cudagraphs_reset, can_use_pre_allocated_outputs};
6073
return values;
6174
}
75+
void set_context_changed() {
76+
has_context_changed = true;
77+
}
6278
};
6379

6480
struct TRTEngine : torch::CustomClassHolder {

core/runtime/execute_engine.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
205205
// Whether cudagraphs needs to record the graph on this pass
206206
RuntimeStates states = compiled_engine->runtime_states.validate_states(
207207
CUDAGRAPHS_MODE, compiled_engine->use_pre_allocated_outputs, shape_changed);
208-
bool need_cudagraphs_record = states.need_cudagraphs_record;
209208

210-
if (!CUDAGRAPHS_MODE || shape_changed) {
209+
if (states.need_cudagraphs_reset) {
211210
compiled_engine->cudagraph.reset();
212211
}
213212

@@ -269,7 +268,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
269268
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
270269
}
271270

272-
setup_input_tensors(inputs, compiled_engine, need_cudagraphs_record);
271+
setup_input_tensors(inputs, compiled_engine, states.need_cudagraphs_record);
273272

274273
// Check if input shapes can be inferred.
275274
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
@@ -297,7 +296,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
297296
for (auto output_indices : compiled_engine->out_binding_map) {
298297
auto pyt_idx = output_indices.second;
299298
std::string name = compiled_engine->out_binding_names[pyt_idx];
300-
if (need_cudagraphs_record) {
299+
if (states.need_cudagraphs_record) {
301300
// If we are recording the cuda graph then we need to update the persistent output buffer
302301
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
303302
}
@@ -349,7 +348,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
349348
// Direct execution uses the caller buffers directly
350349
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
351350
} else {
352-
if (need_cudagraphs_record) {
351+
if (states.need_cudagraphs_record) {
353352
// If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
354353
c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
355354
compiled_engine->cudagraph.capture_begin();

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+33-32
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,30 @@
2424

2525

2626
class TorchTRTRuntimeStates:
27-
def __init__(self, cudagraphs_enabled: bool, pre_allocated_outputs_enabled: bool):
27+
def __init__(self, cudagraphs_enabled: bool):
2828
self.prev_cudagraphs_enabled = cudagraphs_enabled
29-
self.prev_pre_allocated_outputs_enabled = pre_allocated_outputs_enabled
29+
self.prev_pre_allocated_outputs_enabled = False
30+
self.has_context_changed = False
3031

3132
def validate_states(
3233
self,
3334
cudagraphs_enabled: bool,
3435
pre_allocated_outputs_enabled: bool,
3536
shape_changed: bool,
36-
) -> Tuple[bool, bool]:
37-
# Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
37+
) -> Tuple[bool, bool, bool]:
38+
# Evaluates whether certain conditions are met to enable CUDA Graph recording/reset or to reuse pre-allocated outputs
3839
# based on the current and previous states, as well as whether the input shape has changed
3940
need_cudagraphs_record = False
4041
can_use_pre_allocated_outputs = False
42+
need_cudagraphs_reset = False
4143

4244
# Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
43-
if cudagraphs_enabled and (not self.prev_cudagraphs_enabled or shape_changed):
45+
# If the context was changed by a runtime setting such as weight streaming, the CUDA graph must be re-recorded
46+
if cudagraphs_enabled and (
47+
not self.prev_cudagraphs_enabled
48+
or shape_changed
49+
or self.has_context_changed
50+
):
4451
need_cudagraphs_record = True
4552

4653
# Pre-allocated output can be used when previous and current state are true without shape change
@@ -51,10 +58,19 @@ def validate_states(
5158
):
5259
can_use_pre_allocated_outputs = True
5360

61+
if not cudagraphs_enabled or shape_changed or self.has_context_changed:
62+
need_cudagraphs_reset = True
63+
64+
# Reset the flag
65+
self.has_context_changed = False
5466
self.prev_cudagraphs_enabled = cudagraphs_enabled
5567
self.prev_pre_allocated_outputs_enabled = pre_allocated_outputs_enabled
5668

57-
return need_cudagraphs_record, can_use_pre_allocated_outputs
69+
return (
70+
need_cudagraphs_record,
71+
need_cudagraphs_reset,
72+
can_use_pre_allocated_outputs,
73+
)
5874

5975

6076
class PythonTorchTensorRTModule(Module): # type: ignore[misc]
@@ -141,15 +157,12 @@ def __init__(
141157
self.engine = None
142158
self.weight_name_map = weight_name_map
143159
self.target_platform = Platform.current_platform()
144-
<<<<<<< HEAD
160+
145161
self.runtime_states = TorchTRTRuntimeStates(
146-
torch_tensorrt.runtime.get_cudagraphs_mode(), False
162+
torch_tensorrt.runtime.get_cudagraphs_mode()
147163
)
148164
self.pre_allocated_outputs: List[torch.Tensor] = []
149165
self.use_pre_allocated_outputs = False
150-
=======
151-
self.has_context_changed = False
152-
>>>>>>> 7bb66dac4 (fix: Record cudagraphs when weight streaming budget has changed)
153166

154167
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
155168
self.setup_engine()
@@ -169,8 +182,7 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
169182
del self.context
170183
budget_bytes = self._set_device_memory_budget(budget_bytes)
171184
self.context = self.engine.create_execution_context()
172-
# Indicates to reevaluate the runtime settings
173-
self.has_context_changed = True
185+
self.runtime_states.has_context_changed = True
174186

175187
return budget_bytes
176188

@@ -360,36 +372,25 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
360372
self._check_initialized()
361373

362374
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
363-
<<<<<<< HEAD
375+
364376
shape_changed = self.validate_input_shapes(inputs)
365-
need_cudagraphs_record, can_use_pre_allocated_outputs = (
366-
self.runtime_states.validate_states(
367-
cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
368-
)
369-
=======
370-
need_cudagraphs_record = cudagraphs_enabled and (
371-
not self.cudagraphs_validate_shapes(inputs) or self.has_context_changed
372-
>>>>>>> 7bb66dac4 (fix: Record cudagraphs when weight streaming budget has changed)
377+
(
378+
need_cudagraphs_record,
379+
need_cudagraphs_reset,
380+
can_use_pre_allocated_outputs,
381+
) = self.runtime_states.validate_states(
382+
cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
373383
)
374384

375385
if need_cudagraphs_record:
376-
if self.cudagraph:
377-
self.cudagraph.reset()
378386
self._input_buffers = [None] * len(self.input_names)
379387
self._output_buffers = [None] * len(self.output_names)
380388

381-
if self.cudagraph and (not cudagraphs_enabled or self.has_context_changed):
389+
if need_cudagraphs_reset and self.cudagraph:
382390
self.cudagraph.reset()
383391
self.cudagraph = None
384392

385-
<<<<<<< HEAD
386393
# If in safe mode, check at each iteration for whether a switch is required
387-
=======
388-
# Reset the flag
389-
self.has_context_changed = False
390-
391-
# If in safe mode, check at each iteration for for whether a switch is required
392-
>>>>>>> 7bb66dac4 (fix: Record cudagraphs when weight streaming budget has changed)
393394
if (
394395
torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
395396
):

tests/py/dynamo/runtime/test_pre_allocated_outputs.py

+66-37
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import torch
24
import torch_tensorrt as torchtrt
35
from parameterized import parameterized
@@ -62,67 +64,94 @@ def forward(self, x):
6264
)
6365
def test_pre_allocated_outputs_dynamic(self, _, use_python_runtime):
6466
class SampleModel(torch.nn.Module):
67+
def __init__(self):
68+
super().__init__()
69+
self.layer1 = torch.nn.Linear(100, 128)
70+
self.layer2 = torch.nn.Linear(128, 64)
71+
self.relu = torch.nn.ReLU()
72+
6573
def forward(self, x):
66-
return torch.relu((x + 2) * 0.5)
74+
out = self.layer1(x)
75+
out = self.relu((out + 2.0) * 0.05)
76+
out = self.layer2(out)
77+
return out
6778

6879
inputs = torchtrt.Input(
69-
min_shape=(1, 3, 128, 224),
70-
opt_shape=(8, 3, 192, 224),
71-
max_shape=(16, 3, 224, 224),
80+
min_shape=(1, 100),
81+
opt_shape=(64, 100),
82+
max_shape=(128, 100),
7283
dtype=torch.float,
7384
name="x",
7485
)
75-
fx_graph = torch.fx.symbolic_trace(SampleModel())
86+
model = SampleModel().eval().cuda()
87+
fx_graph = torch.fx.symbolic_trace(model)
88+
89+
input_list = []
90+
input_list.append(torch.randn((8, 100)).cuda())
91+
input_list.append(torch.randn((12, 100)).cuda())
92+
input_list.append(torch.randn((12, 100)).cuda())
93+
input_list.append(torch.randn((8, 100)).cuda())
94+
input_list.append(torch.randn((8, 100)).cuda())
7695

7796
optimized_model = torchtrt.compile(
7897
fx_graph,
7998
"dynamo",
8099
inputs,
81100
min_block_size=1,
82101
pass_through_build_failures=True,
102+
use_explicit_typing=True,
103+
enable_weight_streaming=True,
83104
torch_executed_ops={"torch.ops.aten.mul.Tensor"},
84105
use_python_runtime=use_python_runtime,
85106
)
86107

87-
input_list = []
88-
ref_out_list = []
89-
trt_out_list = []
90-
# Alternating cuda_graphs enable and input shapes at every five iterations.
91-
for i in [1, 3, 8, 11, 16]:
92-
for j in [128, 128, 222, 222, 224]:
93-
input_list.append(torch.randn((i, 3, j, 224)).cuda())
108+
# List of tuples representing different configurations for three features:
109+
# Cuda graphs, pre-allocated output buffer, weight streaming change
110+
states = list(itertools.product((True, False), repeat=3))
111+
# Create pairs of these configurations, representing an initial state and a changed state
112+
states_permutations = itertools.permutations(states, 2)
94113

95114
pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs(
96115
optimized_model
97116
)
98-
pre_allocated_output = False
99-
for enable_cuda_graphs in [False, True]:
100-
for i in range(len(input_list)):
101-
# Toggles cuda graph at all index in TRIALS
102-
if i % TRIALS == i // TRIALS:
103-
cuda_graphs = enable_cuda_graphs
104-
else:
105-
cuda_graphs = not enable_cuda_graphs
106-
if i % 3 == 0:
107-
pre_allocated_output = not pre_allocated_output
108-
117+
weight_streaming_ctx = torchtrt.runtime.weight_streaming(optimized_model)
118+
streamable_budget = weight_streaming_ctx.total_device_budget
119+
120+
for init_state, changed_state in states_permutations:
121+
for cuda_graphs, pre_allocated_output, weight_streaming in [
122+
init_state,
123+
changed_state,
124+
]:
109125
torchtrt.runtime.set_cudagraphs_mode(cuda_graphs)
110126
pre_allocated_output_ctx.set_pre_allocated_output(pre_allocated_output)
111127

112-
ref_out_list.append(fx_graph(input_list[i]))
113-
trt_out_list.append(optimized_model(input_list[i]))
114-
115-
for torch_model_results, optimized_model_results in zip(
116-
ref_out_list, trt_out_list
117-
):
118-
torch.testing.assert_close(
119-
torch_model_results,
120-
optimized_model_results,
121-
rtol=5e-03,
122-
atol=5e-03,
123-
equal_nan=True,
124-
check_dtype=True,
125-
)
128+
if weight_streaming:
129+
weight_streaming_ctx.device_budget = int(streamable_budget * 0.8)
130+
else:
131+
weight_streaming_ctx.device_budget = streamable_budget
132+
133+
ref_out_list = []
134+
trt_out_list = []
135+
# Input shape changes
136+
for i in range(len(input_list)):
137+
if weight_streaming and i == 4:
138+
weight_streaming_ctx.device_budget = int(
139+
streamable_budget * 0.6
140+
)
141+
ref_out_list.append(fx_graph(input_list[i]))
142+
trt_out_list.append(optimized_model(input_list[i]))
143+
144+
for torch_model_results, optimized_model_results in zip(
145+
ref_out_list, trt_out_list
146+
):
147+
torch.testing.assert_close(
148+
torch_model_results,
149+
optimized_model_results,
150+
rtol=5e-03,
151+
atol=5e-03,
152+
equal_nan=True,
153+
check_dtype=True,
154+
)
126155
torch._dynamo.reset()
127156

128157

0 commit comments

Comments
 (0)