
✨[Feature] Exclude key functions from tracing during AOT #1959

Closed
@gs-olive


Context

As of #1921, we can exclude modules from acceleration in AOT using the global module-level registry. However, certain functions are still decomposed by the aot_module_simplified call:

return aot_module_simplified(
    gm,
    sample_inputs,
    fw_compiler=make_boxed_compiler(custom_backend),
    decompositions=get_decompositions(),
)

These functions include torch.functional.einsum, which is decomposed as follows:

##### THE MODEL
import torch

class SampleEinsum(torch.nn.Module):
    def forward(self, x):
        # einsum('ii', x) sums the diagonal of a square matrix (its trace)
        return torch.einsum('ii', x)

##### BEFORE AOT
graph():
    %l_x_ : torch.Tensor [#users=1] = placeholder[target=L_x_]
    %einsum : [#users=1] = call_function[target=torch.functional.einsum](args = (ii, %l_x_), kwargs = {})
    return (einsum,)

##### AFTER AOT
graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %as_strided : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [4], [5], 0), kwargs = {})
    %permute : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%as_strided, [0]), kwargs = {})
    %permute_1 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute, [0]), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [0]), kwargs = {})
    return (sum_1,)
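The `as_strided` call is what carries the meaning of `'ii'` in the decomposed graph. A minimal pure-Python sketch, assuming `x` is a contiguous 4x4 tensor (which the size `[4]` and stride `[5]` in the graph imply), shows why: in a row-major buffer, element `(i, i)` sits at flat offset `i * 4 + i = i * 5`, so a 1-D view with stride 5 walks exactly the diagonal. The two `permute(..., [0])` calls are identity permutations on a 1-D view, and the final `sum` over dim 0 yields the trace. The helper names below are illustrative only.

```python
def as_strided_1d(buffer, size, stride, storage_offset=0):
    """Mimic a 1-D aten.as_strided view: gather `size` elements, `stride` apart."""
    return [buffer[storage_offset + k * stride] for k in range(size)]


def einsum_ii(matrix):
    """Reference semantics of einsum('ii', x): sum of the diagonal."""
    return sum(matrix[i][i] for i in range(len(matrix)))


# A 4x4 matrix flattened row-major, as a contiguous tensor is stored.
n = 4
matrix = [[r * n + c for c in range(n)] for r in range(n)]
flat = [v for row in matrix for v in row]

# Mirror the AOT graph: as_strided(x, [4], [5], 0), two no-op permutes, then sum over dim 0.
diagonal = as_strided_1d(flat, size=n, stride=n + 1, storage_offset=0)
decomposed = sum(diagonal)

print(diagonal)           # [0, 5, 10, 15]
print(decomposed)         # 30
print(einsum_ii(matrix))  # 30
```

The point for this issue is that once AOT has lowered einsum into these primitives, the single high-level call that TensorRT could convert directly is no longer visible in the graph.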

Feature Proposal

Augment the module-level replacement registry with a function-level replacement utility, which substitutes TRT-convertible functions such as einsum so that they are excluded from AOT tracing and reach TensorRT intact.
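One possible shape for such a utility is sketched below. Everything in it is hypothetical and not an existing Torch-TensorRT API: `FUNCTION_REPLACEMENTS`, `register_function_replacement`, `replace_functions`, and `trt_einsum` are illustrative names, and a small `Node` dataclass stands in for `torch.fx.Node` so the example runs without PyTorch. Functions are keyed by qualified name for the same reason; a real pass would key on the function objects themselves, and the replacement would have to be opaque to AOT tracing (for example, a custom operator) for the call to survive decomposition.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

# Hypothetical registry: qualified function name -> replacement target.
FUNCTION_REPLACEMENTS: Dict[str, Callable] = {}


def register_function_replacement(qualified_name: str):
    """Decorator registering `fn` as the replacement for `qualified_name`."""
    def decorator(fn: Callable) -> Callable:
        FUNCTION_REPLACEMENTS[qualified_name] = fn
        return fn
    return decorator


@dataclass
class Node:
    """Toy stand-in for torch.fx.Node (only the fields this sketch uses)."""
    op: str
    target: object
    args: Tuple = field(default_factory=tuple)


def replace_functions(nodes: List[Node]) -> int:
    """Retarget registered call_function nodes in place; return how many changed."""
    replaced = 0
    for node in nodes:
        if node.op == "call_function" and node.target in FUNCTION_REPLACEMENTS:
            node.target = FUNCTION_REPLACEMENTS[node.target]
            replaced += 1
    return replaced


@register_function_replacement("torch.functional.einsum")
def trt_einsum(equation, *operands):
    """Hypothetical placeholder: would be lowered by a TensorRT converter, not run in Python."""
    raise NotImplementedError("placeholder for an opaque, TRT-convertible op")


# Mirror the 'BEFORE AOT' graph from the issue.
graph = [
    Node("placeholder", "L_x_"),
    Node("call_function", "torch.functional.einsum", ("ii", "l_x_")),
    Node("output", "output", ("einsum",)),
]
print(replace_functions(graph))       # 1
print(graph[1].target is trt_einsum)  # True
```

Keeping the registry separate from the rewrite pass mirrors the module-level design: converters register what they can handle, and a single pass applies all registered replacements before `aot_module_simplified` runs.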
