Context
As of #1921, modules can be excluded from acceleration in AOT via the global module-level registry. However, certain functions are still decomposed as a result of `aot_module_simplified`:
TensorRT/py/torch_tensorrt/dynamo/backend/backends.py, lines 50 to 55 at e109049
These functions include `torch.functional.einsum`, which is decomposed as follows:
```python
##### THE MODEL
import torch

class SampleEinsum(torch.nn.Module):
    def forward(self, x):
        return torch.einsum('ii', x)
```
```
##### BEFORE AOT
graph():
    %l_x_ : torch.Tensor [#users=1] = placeholder[target=L_x_]
    %einsum : [#users=1] = call_function[target=torch.functional.einsum](args = (ii, %l_x_), kwargs = {})
    return (einsum,)
```
```
##### AFTER AOT
graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %as_strided : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [4], [5], 0), kwargs = {})
    %permute : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%as_strided, [0]), kwargs = {})
    %permute_1 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute, [0]), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [0]), kwargs = {})
    return (sum_1,)
```
Feature Proposal
Augment the module-level replacement registry with a function-level replacement utility, so that TRT-convertible functions like `torch.einsum` can be replaced directly and excluded from AOT tracing and decomposition.
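One possible shape for such a utility is a registry keyed on the original callable, consulted by a lowering pass that runs before AOT tracing. The sketch below is purely illustrative: the registry name, decorator, and the stand-in `einsum`/`trt_einsum` functions are hypothetical and not part of the Torch-TensorRT API.

```python
# Hypothetical function-level substitution registry (illustrative only;
# names here are assumptions, not the actual Torch-TensorRT API).
FUNCTION_SUBSTITUTION_REGISTRY = {}

def register_substitution(target):
    """Decorator: register `replacement` as the TRT-convertible stand-in
    for `target`, so a pre-AOT pass can swap it in before tracing
    decomposes the original call."""
    def decorator(replacement):
        FUNCTION_SUBSTITUTION_REGISTRY[target] = replacement
        return replacement
    return decorator

def pre_aot_substitution(calls):
    """Toy pre-AOT lowering pass over (target, args) call records:
    registered targets are replaced, everything else is left for AOT."""
    return [
        (FUNCTION_SUBSTITUTION_REGISTRY.get(fn, fn), args)
        for fn, args in calls
    ]

# --- Illustration with stand-in functions ---
def einsum(equation, operand):
    # Stands in for torch.functional.einsum in this sketch.
    raise NotImplementedError

@register_substitution(einsum)
def trt_einsum(equation, operand):
    # Stands in for a TRT-convertible einsum implementation.
    return ("trt::einsum", equation, operand)

graph_calls = [(einsum, ("ii", "x")), (print, ("unrelated",))]
lowered = pre_aot_substitution(graph_calls)
assert lowered[0][0] is trt_einsum   # registered target was substituted
assert lowered[1][0] is print        # unregistered targets are untouched
```

In the real backend this pass would operate on the torch.fx graph captured before `aot_module_simplified`, rewriting `call_function` nodes whose targets appear in the registry so the decomposition above never fires for them.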