Skip to content

Commit 2e1d282

Browse files
committed
[pytorch/executorch][diff_train] [ez][release blocker fix] Insert linalg_vector_norm into decomp table used for Edge export (#9938)
## Context Addresses this [release blocker](https://github.com/orgs/pytorch/projects/99/views/1?pane=issue&itemId=104088363&issue=pytorch%7Cpytorch%7C150207) issue. Some models cannot export because they use `linalg_vector_norm` which is not currently an ATen operator. I initially tried adding the op to the core decomp table, but the decomp is not passing pytorch correctness tests. Please see pytorch/pytorch#150241 for more details. ## Changes Since we currently cannot include the op in PyTorch's decomp table, instead we can insert the op into the edge decomp table directly. This PR is a simple change to add `linalg_vector_norm` directly to the edge decomp table. Internal: << DO NOT EDIT BELOW THIS LINE >> **GitHub Author**: Sicheng Stephen Jia <ssjia@meta.com> (Meta Employee) **GitHub Repo**: [pytorch/executorch](https://github.com/pytorch/executorch) **GitHub Pull Request**: [#9938](#9938) Initially generated by: https://www.internalfb.com/intern/sandcastle/job/4503601401156690/ This was imported as part of a Diff Train. Please review this as soon as possible. Since it is a direct copy of a commit on GitHub, there shouldn't be much to do. diff-train-source-id: c2e3e17 Differential Revision: [D72729066](https://our.internmc.facebook.com/intern/diff/D72729066/) [ghstack-poisoned]
1 parent 1facfa9 commit 2e1d282

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

exir/program/test/test_program.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -725,17 +725,17 @@ def count_nodes(graph_module, target):
725725
)
726726

727727
def test_edge_dialect_non_core_aten_ops(self):
728-
class LinalgNorm(torch.nn.Module):
728+
class LinalgRank(torch.nn.Module):
729729
def __init__(self):
730730
super().__init__()
731731

732732
def forward(self, x: torch.Tensor) -> torch.Tensor:
733-
return torch.linalg.norm(x)
733+
return torch.linalg.matrix_rank(x)
734734

735735
from torch._export.verifier import SpecViolationError
736736

737-
input = torch.arange(9, dtype=torch.float) - 4
738-
ep = torch.export.export(LinalgNorm(), (input,), strict=True)
737+
input = torch.ones((9, 9, 9), dtype=torch.float)
738+
ep = torch.export.export(LinalgRank(), (input,), strict=True)
739739

740740
# aten::linalg_norm is not a core op, so it should error out
741741
with self.assertRaises(SpecViolationError):
@@ -748,9 +748,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
748748
ep,
749749
compile_config=EdgeCompileConfig(
750750
_check_ir_validity=True,
751-
_core_aten_ops_exception_list=[
752-
torch.ops.aten.linalg_vector_norm.default
753-
],
751+
_core_aten_ops_exception_list=[torch.ops.aten._linalg_svd.default],
754752
),
755753
)
756754
except SpecViolationError:

exir/tracer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,8 +631,18 @@ def _default_decomposition_table(
631631
]
632632
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
633633
return get_decompositions(decomp_opset)
634+
635+
decomps = default_decompositions()
636+
# Add edge specific decompositions
637+
additional_decomp_ops = [
638+
# TODO: Eventually this op should be added to the core decompo table, and will not
639+
# need to be added here.
640+
torch.ops.aten.linalg_vector_norm.default,
641+
]
642+
additional_decomps = get_decompositions(additional_decomp_ops)
643+
decomps.update(additional_decomps)
634644
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
635-
return default_decompositions()
645+
return decomps
636646

637647

638648
def dynamo_trace(

0 commit comments

Comments
 (0)