class TorchTRTRuntimeStates:
    """Tracks CUDA Graphs / pre-allocated-output state across forward calls.

    An instance remembers the settings seen on the previous call so that
    ``validate_states`` can decide whether a CUDA Graph must be (re)recorded
    or reset, and whether previously pre-allocated outputs may be reused.
    """

    def __init__(self, cudagraphs_enabled: bool) -> None:
        # CUDA Graphs mode observed on the previous call.
        self.prev_cudagraphs_enabled = cudagraphs_enabled
        # Pre-allocated-outputs mode observed on the previous call.
        self.prev_pre_allocated_outputs_enabled = False
        # Set externally when the execution context is recreated at runtime
        # (e.g. the weight-streaming device-memory budget changed); forces a
        # CUDA Graph record/reset on the next call and is then cleared.
        self.has_context_changed = False

    def validate_states(
        self,
        cudagraphs_enabled: bool,
        pre_allocated_outputs_enabled: bool,
        shape_changed: bool,
    ) -> Tuple[bool, bool, bool]:
        """Evaluate whether CUDA Graph recording/reset is needed and whether
        pre-allocated outputs can be reused, based on the current settings,
        the previous call's settings, and whether the input shape changed.

        Returns:
            (need_cudagraphs_record, need_cudagraphs_reset,
             can_use_pre_allocated_outputs)
        """
        need_cudagraphs_record = False
        can_use_pre_allocated_outputs = False
        need_cudagraphs_reset = False

        # Cudagraphs record is required if cudagraphs_enabled is switched to
        # True regardless of shape change. If the context was changed by a
        # runtime setting like weight streaming, it also needs a record.
        if cudagraphs_enabled and (
            not self.prev_cudagraphs_enabled
            or shape_changed
            or self.has_context_changed
        ):
            need_cudagraphs_record = True

        # Pre-allocated output can be used when previous and current state are
        # true without shape change.
        # NOTE(review): this condition sits in hidden diff context in the
        # original; reconstructed from the surrounding comment — confirm
        # against upstream.
        if (
            self.prev_pre_allocated_outputs_enabled
            and pre_allocated_outputs_enabled
            and not shape_changed
        ):
            can_use_pre_allocated_outputs = True

        # A previously captured graph is stale once cudagraphs is disabled,
        # the input shape changed, or the execution context was recreated.
        if not cudagraphs_enabled or shape_changed or self.has_context_changed:
            need_cudagraphs_reset = True

        # The context-changed flag is one-shot: consume it here.
        self.has_context_changed = False
        self.prev_cudagraphs_enabled = cudagraphs_enabled
        self.prev_pre_allocated_outputs_enabled = pre_allocated_outputs_enabled

        return (
            need_cudagraphs_record,
            need_cudagraphs_reset,
            can_use_pre_allocated_outputs,
        )
class PythonTorchTensorRTModule (Module ): # type: ignore[misc]
@@ -141,15 +157,12 @@ def __init__(
141
157
self .engine = None
142
158
self .weight_name_map = weight_name_map
143
159
self .target_platform = Platform .current_platform ()
144
- < << << << HEAD
160
+
145
161
self .runtime_states = TorchTRTRuntimeStates (
146
- torch_tensorrt .runtime .get_cudagraphs_mode (), False
162
+ torch_tensorrt .runtime .get_cudagraphs_mode ()
147
163
)
148
164
self .pre_allocated_outputs : List [torch .Tensor ] = []
149
165
self .use_pre_allocated_outputs = False
150
- == == == =
151
- self .has_context_changed = False
152
- >> >> >> > 7 bb66dac4 (fix : Record cudagraphs when weight streaming budget has changed )
153
166
154
167
if self .serialized_engine is not None and not self .settings .lazy_engine_init :
155
168
self .setup_engine ()
@@ -169,8 +182,7 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
169
182
del self .context
170
183
budget_bytes = self ._set_device_memory_budget (budget_bytes )
171
184
self .context = self .engine .create_execution_context ()
172
- # Indicates to reevaluate the runtime settings
173
- self .has_context_changed = True
185
+ self .runtime_states .has_context_changed = True
174
186
175
187
return budget_bytes
176
188
@@ -360,36 +372,25 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
360
372
self ._check_initialized ()
361
373
362
374
cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
363
- < << << << HEAD
375
+
364
376
shape_changed = self .validate_input_shapes (inputs )
365
- need_cudagraphs_record , can_use_pre_allocated_outputs = (
366
- self .runtime_states .validate_states (
367
- cudagraphs_enabled , self .use_pre_allocated_outputs , shape_changed
368
- )
369
- == == == =
370
- need_cudagraphs_record = cudagraphs_enabled and (
371
- not self .cudagraphs_validate_shapes (inputs ) or self .has_context_changed
372
- >> > >> >> 7 bb66dac4 (fix : Record cudagraphs when weight streaming budget has changed )
377
+ (
378
+ need_cudagraphs_record ,
379
+ need_cudagraphs_reset ,
380
+ can_use_pre_allocated_outputs ,
381
+ ) = self .runtime_states .validate_states (
382
+ cudagraphs_enabled , self .use_pre_allocated_outputs , shape_changed
373
383
)
374
384
375
385
if need_cudagraphs_record :
376
- if self .cudagraph :
377
- self .cudagraph .reset ()
378
386
self ._input_buffers = [None ] * len (self .input_names )
379
387
self ._output_buffers = [None ] * len (self .output_names )
380
388
381
- if self . cudagraph and ( not cudagraphs_enabled or self .has_context_changed ) :
389
+ if need_cudagraphs_reset and self .cudagraph :
382
390
self .cudagraph .reset ()
383
391
self .cudagraph = None
384
392
385
- << << < << HEAD
386
393
# If in safe mode, check at each iteration for whether a switch is required
387
- == == == =
388
- # Reset the flag
389
- self .has_context_changed = False
390
-
391
- # If in safe mode, check at each iteration for for whether a switch is required
392
- >> >> > >> 7 bb66dac4 (fix : Record cudagraphs when weight streaming budget has changed )
393
394
if (
394
395
torch_tensorrt .runtime ._multi_device_safe_mode ._PY_RT_MULTI_DEVICE_SAFE_MODE
395
396
):
0 commit comments