Skip to content

Commit a905728

Browse files
ToOutVarPass skips inplace ops
Differential Revision: D74833331 Pull Request resolved: #10921
1 parent 95e27ed commit a905728

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

exir/passes/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
2525
from executorch.exir.error import InternalError
2626
from executorch.exir.operator.convert import (
27+
_get_overload_schema,
2728
get_out_args_from_opoverload,
2829
is_out_variant,
2930
to_out_variant,
@@ -63,6 +64,7 @@
6364
from torch._subclasses import FakeTensor
6465
from torch.fx.passes.infra.pass_base import PassBase, PassResult
6566
from torch.fx.passes.shape_prop import TensorMetadata
67+
from torchgen.model import SchemaKind
6668

6769
__all__ = [
6870
"ExportPass",
@@ -257,7 +259,6 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
257259
memory.alloc,
258260
memory.view,
259261
executorch_call_delegate,
260-
torch.ops.aten.copy_.default,
261262
}
262263
to_out_var_skiplist.update(_EXECUTORCH_SYM_OPS)
263264

@@ -347,6 +348,8 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
347348
continue
348349
elif target in to_out_var_skiplist:
349350
continue
351+
elif _get_overload_schema(target).kind() == SchemaKind.inplace:
352+
continue
350353
if not isinstance(
351354
target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
352355
):

0 commit comments

Comments
 (0)