Skip to content

Commit c211c98

Browse files
committed
chore: Proper logging message and rebase
1 parent 11886fe commit c211c98

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

py/torch_tensorrt/runtime/_cudagraphs.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any
2+
from typing import Any, Union
33

44
import torch
55
import torch_tensorrt
@@ -88,17 +88,15 @@ def __enter__(self) -> torch.nn.Module:
8888
torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)
8989

9090
logger.debug(
91-
f"{num_torch_module} torch modules are in subgraphs. Using wrapper module for cuda graphs"
91+
"Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
9292
)
9393
return CudaGraphsTorchTensorRTModule(self.compiled_module)
9494
else:
9595
if num_trt_module > 0:
96-
logger.debug(
97-
"There is no graph breaks. Using original module for cuda graphs"
98-
)
96+
logger.debug("No graph breaks detected, using runtime cudagraphs mode")
9997
else:
100-
logger.warning(
101-
"Please consider dynamo if there is graph breaks. Using original module for cuda graphs"
98+
logger.debug(
99+
"Please consider dynamo if there is graph breaks. Using runtime cudagraphs mode"
102100
)
103101
# Enable cudagraphs for TRT submodule
104102
set_cudagraphs_mode(True)
@@ -110,6 +108,6 @@ def __exit__(self, *args: Any) -> None:
110108

111109

112110
def enable_cudagraphs(
113-
compiled_module: torch.nn.Module,
111+
compiled_module: Union[torch.fx.GraphModule, torch.nn.Module],
114112
) -> _CudagraphsContextManager:
115113
return _CudagraphsContextManager(compiled_module)

py/torch_tensorrt/runtime/_weight_streaming.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any
2+
from typing import Any, Union
33

44
import torch
55
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
@@ -16,7 +16,7 @@ class _WeightStreamingContextManager(object):
1616
"""
1717

1818
def __init__(
19-
self, module: torch.fx.GraphModule | CudaGraphsTorchTensorRTModule
19+
self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule]
2020
) -> None:
2121
rt_mods = []
2222
self.current_device_budget = 0

0 commit comments

Comments
 (0)