Description
❓ Question
I'm trying to compile a model with dynamic input shapes, but I'm told that the function torch._ops.aten.aten::_to_copy is not currently supported:
  File "/home/wh/generative_action/SynHSI/test_module.py", line 325, in <module>
    model = torch_tensorrt.compile(model, ir="dynamo", inputs=trt_inputs)
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 249, in compile
    trt_graph_module = dynamo_compile(
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 243, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 431, in compile_module
    trt_module = convert_module(
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 107, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 88, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 336, in run
    self._construct_trt_network_def()
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 317, in _construct_trt_network_def
    super().run()
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch/fx/interpreter.py", line 147, in run
    self.env[node] = self.run_node(node)
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 378, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch/fx/interpreter.py", line 204, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/wh/miniconda3/envs/hsi-torch-dev/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 480, in call_function
    raise UnsupportedOperatorException(
torch_tensorrt.dynamo.conversion._TRTInterpreter.UnsupportedOperatorException: Conversion of function torch._ops.aten.aten::_to_copy not currently supported!
The code that causes this error is as follows:
pi = self.positional_encoder.pos_encoding[pi.long()]
where self.positional_encoder is an instance of a custom implementation of the transformer positional encoder:
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)
        # Info
        self.dropout = nn.Dropout(dropout_p)

        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).reshape(-1, 1)  # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model
        )  # 1 / 10000^(2i/dim_model)

        # PE(pos, 2i) = sin(pos / 10000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)

        # Saving buffer (same as parameter without gradients needed)
        # Resulting shape: (max_len, 1, dim_model)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
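As a sanity check on the encoding formula itself (independent of the compilation error), here is a minimal pure-Python sketch that recomputes one row of the table. The helper name sinusoidal_row is hypothetical and not part of my model; it assumes an even dim_model and uses the same 10000 base as the math.log(10000.0) term in the class above.

```python
import math


def sinusoidal_row(pos: int, dim_model: int) -> list[float]:
    """Encoding for a single position (hypothetical helper; assumes even dim_model)."""
    row = [0.0] * dim_model
    for i in range(0, dim_model, 2):
        # Same value as division_term in the class: 10000^(-i / dim_model), i = 0, 2, 4, ...
        freq = math.exp(i * (-math.log(10000.0)) / dim_model)
        row[i] = math.sin(pos * freq)      # even index -> sine
        row[i + 1] = math.cos(pos * freq)  # odd index  -> cosine
    return row


if __name__ == "__main__":
    print(sinusoidal_row(0, 4))
    print(sinusoidal_row(1, 4))
```

Position 0 gives alternating sin(0) and cos(0) values, and higher column indices vary more slowly with position, which is what the buffer in the class should contain row by row.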
What you have already tried
The complete model is complicated, so I tried to build a minimal reproducible example, but compiling a standalone PositionalEncoding model succeeds. I also tried adding more of the surrounding code, but compilation still succeeds, so I have not been able to produce a minimal reproducible example yet.
I found that this error only occurs with dynamic input shapes. Compiling the model with fixed input shapes works fine.
Besides, I noticed that #2161 added the _to_copy converter, so I'm confused about why it reports that _to_copy is not supported. Or am I misunderstanding something?
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- PyTorch Version (e.g., 1.0): 2.5.0.dev20240804+cu118
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu 20.04
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Python version: 3.10
- CUDA version: 11.8
- torch_tensorrt: 2.5.0.dev20240804+cu124