
🐛 [Bug] Error when weight streaming and cuda graphs are used #3308

Closed
@keehyuna

Description

Bug Description

A "RuntimeError: CUDA error: invalid argument" is raised when CUDA graphs are enabled and the weight streaming budget has changed since the graph was captured.
It seems the CUDA graph needs to be re-recorded whenever the weight streaming budget changes.
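The re-record condition suggested above can be sketched in plain Python. This is a hypothetical illustration, not the actual Torch-TensorRT runtime code: the class name `CudaGraphWrapper` and its methods are invented for the sketch, with the capture/replay calls stubbed out.

```python
class CudaGraphWrapper:
    """Hypothetical sketch: replay a captured CUDA graph only while the
    weight streaming budget matches the one it was recorded with."""

    def __init__(self, fn):
        self.fn = fn
        self.recorded_budget = None  # budget the current graph was captured with

    def __call__(self, budget, *args):
        if budget != self.recorded_budget:
            # Budget changed: the captured graph would reference stale weight
            # allocations, so re-record instead of replaying (the missing step
            # that the reported "invalid argument" error points at).
            self.recorded_budget = budget
            return self._record(*args)
        return self._replay(*args)

    def _record(self, *args):
        # Stub for cudaGraph capture of self.fn under the new budget.
        return self.fn(*args)

    def _replay(self, *args):
        # Stub for launching the previously captured graph.
        return self.fn(*args)
```

Under this sketch, changing the budget between two calls triggers a fresh capture rather than a replay of a now-invalid graph.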

To Reproduce

import torch
import torch_tensorrt as torchtrt

# SampleModel and INPUT_SIZE are placeholders from the original report
model = SampleModel().eval().cuda()
input = [torch.randn(*INPUT_SIZE, dtype=torch.float32).cuda()]
fx_graph = torch.fx.symbolic_trace(model)

optimized_model = torchtrt.compile(
    fx_graph,
    inputs=input,
    ir="dynamo",
    min_block_size=1,
    cache_built_engines=False,
    reuse_cached_engines=False,
    use_python_runtime=True,
    use_explicit_typing=True,
    enable_weight_streaming=True,
)

torchtrt.runtime.set_cudagraphs_mode(True)

# Weight streaming context keeps the current device budget size

with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
    new_budget = int(weight_streaming_ctx.total_device_budget * 0.2)
    weight_streaming_ctx.device_budget = new_budget
    optimized_model(*input)

    new_budget = int(weight_streaming_ctx.total_device_budget * 0.4)
    weight_streaming_ctx.device_budget = new_budget
    optimized_model(*input)  # RuntimeError: CUDA error: invalid argument

Expected behavior

No CUDA runtime error; the second forward pass should succeed after the budget change.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context
