Bug Description
WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.bf16: 10>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)
DEBUG:torch_tensorrt.dynamo.backend.backends:Pre-AOT Autograd graph:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%l_x_,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.lowering._repair_input_aliasing:Inserted auxiliary clone nodes for placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%clone_default,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.lowering._remove_sym_nodes:Removed SymInt placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%clone_default,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%clone_default,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.backend.backends:Post-AOT Autograd graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_param_constant0, [1, 0]), kwargs = {})
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %clone, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone from graph, since it is a clone node which is the only user of placeholder arg0_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_param_constant0, [1, 0]), kwargs = {})
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.addmm.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
WARNING:torch_tensorrt.dynamo._compiler:Node _param_constant1 of op type get_attr does not have metadata. This could sometimes lead to undefined behavior.
WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.addmm.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
Input shapes: [(128, 20)]
graph():
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {})
    return addmm
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +3, GPU +0, now: CPU 12984, GPU 1045 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2657, GPU +308, now: CPU 15907, GPU 1353 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg0_1 [shape=[128, 20], dtype=DataType.BF16]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node addmm (kind: aten.addmm.default, args: ('<torch.Tensor as np.ndarray [shape=(30,), dtype=float32]>', 'arg0_1 <tensorrt.ITensor [shape=(128, 20), dtype=DataType.BF16]>', '<torch.Tensor as np.ndarray [shape=(20, 30), dtype=float32]>'))
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Freezing tensor addmm_constant_0 to TRT IConstantLayer
Traceback (most recent call last):
File "C:\Users\HolyWu\Downloads\test.py", line 29, in <module>
optimized_model(*inputs)
File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1552, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\eval_frame.py", line 432, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1552, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 1115, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 947, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 471, in __call__
return _compile(
^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_utils_internal.py", line 83, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_strobelight\compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 816, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\utils.py", line 232, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 635, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\bytecode_transformation.py", line 1184, in transform_code_object
transformations(instructions, code_options)
File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 177, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 581, in transform
tracer.run()
File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2455, in run
super().run()
File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 897, in run
while self.step():
^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 809, in step
self.dispatch_table[inst.opcode](self, inst)
File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2646, in RETURN_VALUE
self._return(inst)
File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2631, in _return
self.output.compile_subgraph(
File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1097, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "C:\Python312\Lib\contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1314, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\utils.py", line 232, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1405, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1386, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\_dynamo\repro\after_dynamo.py", line 128, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\__init__.py", line 1989, in __call__
return self.compiler_fn(model_, inputs_, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\backend\backends.py", line 44, in torch_tensorrt_backend
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\backend\backends.py", line 52, in aot_torch_tensorrt_aten_backend
return _pretraced_backend(gm, sample_inputs, settings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\backend\backends.py", line 108, in _pretraced_backend
trt_compiled = compile_module(
^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\_compiler.py", line 412, in compile_module
trt_module = convert_module(
^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_conversion.py", line 106, in convert_module
interpreter_result = interpret_module_to_result(module, inputs, settings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_conversion.py", line 87, in interpret_module_to_result
interpreter_result = interpreter.run()
^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 310, in run
super().run()
File "C:\Python312\Lib\site-packages\torch\fx\interpreter.py", line 145, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 349, in run_node
trt_node: torch.fx.Node = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch\fx\interpreter.py", line 202, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 457, in call_function
return converter(self.ctx, target, args, kwargs, self._cur_node_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\converter_utils.py", line 469, in convert_with_type_enforcement
return func(ctx, target, new_args, new_kwargs, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\aten_ops_converters.py", line 2714, in aten_ops_addmm
return impl.addmm.addmm(
^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\impl\addmm.py", line 24, in addmm
mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\impl\matmul.py", line 28, in matrix_multiply
other = get_trt_tensor(
^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\converter_utils.py", line 328, in get_trt_tensor
return create_constant(ctx, input_val, name, dtype, min_rank)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\converter_utils.py", line 287, in create_constant
value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python312\Lib\site-packages\torch_tensorrt\_enums.py", line 279, in to
raise TypeError("Unspported numpy dtype")
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt_backend' raised:
TypeError: Unspported numpy dtype
While executing %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x000001E4A3D7F930>: ((128, 20), torch.bfloat16, False, (20, 1), torch.contiguous_format, False, {})}})
Original traceback:
None
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
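
From the traceback, the error originates in `create_constant` in `converter_utils.py`: the constant dtype handed to `_enums.dtype._from(dtype).to(np.dtype)` appears to be bfloat16, which has no numpy equivalent, so the conversion raises. The underlying limitation is easy to demonstrate outside Torch-TensorRT:

```python
import numpy as np
import torch

# numpy has no native bfloat16 dtype, so both of these mappings fail:
try:
    np.dtype("bfloat16")
except TypeError as e:
    print(e)  # data type 'bfloat16' not understood

try:
    torch.zeros(1, dtype=torch.bfloat16).numpy()
except TypeError as e:
    print(e)  # Got unsupported ScalarType BFloat16
```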
To Reproduce
```python
import torch
import torch_tensorrt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(20, 30)

    def forward(self, x):
        return self.linear(x)


device = torch.device("cuda", 0)
model = MyModule().eval().to(device).bfloat16()
inputs = [torch.randn((128, 20), dtype=torch.bfloat16, device=device)]

with torch.inference_mode():
    optimized_model = torch_tensorrt.compile(
        model,
        ir="torch_compile",
        inputs=inputs,
        enabled_precisions={torch.bfloat16},
        debug=True,
        min_block_size=1,
        device=device,
    )
    optimized_model(*inputs)
```
Environment
- Torch-TensorRT Version (e.g. 1.0.0): 2.4.0.dev20240607+cu124
- PyTorch Version (e.g. 1.0): 2.4.0.dev20240607+cu124
- CPU Architecture: x64
- OS (e.g., Linux): Windows 11
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.12.3
- CUDA version: 12.4
- GPU models and configuration: RTX 3050
- Any other relevant information:
Additional context
Adding a `use_default=True` argument to the `to(np.dtype)` call at
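
A minimal sketch of that suggestion, written as a hypothetical patch against the call shown in the traceback. It assumes `use_default=True` makes the bfloat16-to-numpy mapping fall back to a numpy-representable default (float32) instead of raising:

```python
import numpy as np
import torch
from torch_tensorrt import _enums

# Call site from the traceback (converter_utils.py, create_constant, line 287):
#   value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None
# Suggested change (hypothetical):
#   value, _enums.dtype._from(dtype).to(np.dtype, use_default=True) if dtype is not None else None

# Standalone check of the assumed fallback behavior:
print(_enums.dtype._from(torch.bfloat16).to(np.dtype, use_default=True))
# expected (assumption): <class 'numpy.float32'> rather than TypeError
```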