Bug Description
`RuntimeError: CUDA error: invalid argument` is raised when CUDA Graphs is enabled and the weight streaming budget is changed between inference calls.
It appears that the CUDA graph needs to be re-recorded whenever the weight streaming budget changes.
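The suspected cause, a captured graph going stale once the budget changes, can be illustrated with a small pure-Python sketch of the invalidation rule. `BudgetAwareGraphCache` and `record_fn` are hypothetical names, not Torch-TensorRT APIs; `record_fn` stands in for CUDA graph capture. The idea is only to remember the budget a graph was recorded under and re-record when the current budget differs.

```python
from typing import Callable, Optional


class BudgetAwareGraphCache:
    """Hypothetical sketch: re-record a captured graph whenever the weight
    streaming budget differs from the one it was recorded under."""

    def __init__(self, record_fn: Callable[[int], object]) -> None:
        # record_fn is a hypothetical stand-in for CUDA graph capture
        self._record_fn = record_fn
        self._graph: Optional[object] = None
        # None never equals an int budget, so the first call always records
        self._recorded_budget: Optional[int] = None
        self.record_count = 0

    def get_graph(self, current_budget: int) -> object:
        # Re-capture if the budget changed: a graph recorded under the old
        # budget presumably no longer matches the engine's weight placement
        if self._recorded_budget != current_budget:
            self._graph = self._record_fn(current_budget)
            self._recorded_budget = current_budget
            self.record_count += 1
        return self._graph


if __name__ == "__main__":
    cache = BudgetAwareGraphCache(lambda budget: f"graph recorded at budget={budget}")
    total_device_budget = 1000  # hypothetical streamable weight size in bytes
    for fraction in (0.2, 0.2, 0.4):
        graph = cache.get_graph(int(total_device_budget * fraction))
        print(graph, "| recordings so far:", cache.record_count)
```

In an actual fix, this check would sit wherever the runtime decides between replaying and capturing a graph; the sketch only shows the rule itself.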
To Reproduce
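```python
import torch
import torch_tensorrt as torchtrt


# Hypothetical stand-in: the report does not show SampleModel or INPUT_SIZE
class SampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, 64)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


INPUT_SIZE = (8, 64)  # hypothetical input shape

model = SampleModel().eval().cuda()
inputs = [torch.randn(*INPUT_SIZE, dtype=torch.float32).cuda()]
fx_graph = torch.fx.symbolic_trace(model)
optimized_model = torchtrt.compile(
    fx_graph,
    inputs=inputs,
    ir="dynamo",
    min_block_size=1,
    cache_built_engines=False,
    reuse_cached_engines=False,
    use_python_runtime=True,
    use_explicit_typing=True,  # weight streaming requires a strongly typed network
    enable_weight_streaming=True,
)

# Enable CUDA Graphs for subsequent inference calls
torchtrt.runtime.set_cudagraphs_mode(True)

# Weight streaming context keeps current device budget size
with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
    # First run with 20% of the streamable weights resident on the device
    new_budget = int(weight_streaming_ctx.total_device_budget * 0.2)
    weight_streaming_ctx.device_budget = new_budget
    optimized_model(*inputs)

    # Second run after raising the budget to 40%: fails with
    # "RuntimeError: CUDA error: invalid argument" per this report
    new_budget = int(weight_streaming_ctx.total_device_budget * 0.4)
    weight_streaming_ctx.device_budget = new_budget
    optimized_model(*inputs)
```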
Expected behavior
Inference succeeds after each budget change, with no CUDA runtime error.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
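Several of the fields above can be collected automatically. The following is a minimal sketch using only the standard library plus optional imports of `torch` and `torch_tensorrt`; the helper names `module_version` and `collect_environment` are hypothetical. It reads `__version__` where available and uses `torch.version.cuda` and `torch.cuda.get_device_name` for CUDA details. The install method and build command still have to be filled in by hand.

```python
import importlib
import platform


def module_version(name: str) -> str:
    # Report the module's __version__, or note that it could not be imported
    try:
        module = importlib.import_module(name)
    except Exception:  # ImportError, or e.g. missing TensorRT shared libraries
        return "not installed / failed to import"
    return str(getattr(module, "__version__", "unknown"))


def collect_environment() -> dict:
    """Sketch: gather the Environment fields that can be detected automatically."""
    info = {
        "Torch-TensorRT Version": module_version("torch_tensorrt"),
        "PyTorch Version": module_version("torch"),
        "CPU Architecture": platform.machine(),
        "OS": platform.platform(),
        "Python version": platform.python_version(),
    }
    try:
        import torch

        # torch.version.cuda is None for CPU-only builds
        info["CUDA version"] = str(torch.version.cuda)
        if torch.cuda.is_available():
            names = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
            info["GPU models and configuration"] = ", ".join(names)
    except Exception:
        info["CUDA version"] = "unknown (torch not importable)"
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"- {key}: {value}")
```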