Skip to content

Commit 151d6da

Browse files
SS-JIApytorchbot
authored andcommitted
[ez][release blocker fix] Insert linalg_vector_norm into decomp table used for Edge export (#9938)
Summary: ## Context Addresses this [release blocker](https://github.com/orgs/pytorch/projects/99/views/1?pane=issue&itemId=104088363&issue=pytorch%7Cpytorch%7C150207) issue. Some models cannot export because they use `linalg_vector_norm` which is not currently an ATen operator. I initially tried adding the op to the core decomp table, but the decomp is not passing pytorch correctness tests. Please see pytorch/pytorch#150241 for more details. ## Changes Since we currently cannot include the op in PyTorch's decomp table, instead we can insert the op into the edge decomp table directly. This PR is a simple change to add `linalg_vector_norm` directly to the edge decomp table. Test Plan: Tested exporting and running a model with the `linalg_vector_norm` op via the following script. ``` import torch from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig from torch.export import Dim, export from executorch.extension.pybindings.portable_lib import ( # @Manual _load_for_executorch_from_buffer, ) class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.linalg.vector_norm(x, 2) model = Model() inputs = (torch.randn(1,1,16,16),) dynamic_shapes = { "x": { 2: Dim("h", min=16, max=1024), 3: Dim("w", min=16, max=1024), } } exported_program = export(model, inputs, dynamic_shapes=dynamic_shapes) executorch_program = to_edge_transform_and_lower( exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False), ).to_executorch() executorch_module = _load_for_executorch_from_buffer( executorch_program.buffer ) model_output = executorch_module.run_method( "forward", tuple(inputs) ) print(model_output) ``` (cherry picked from commit c2e3e17)
1 parent d53eff4 commit 151d6da

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

exir/program/test/test_program.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -725,17 +725,17 @@ def count_nodes(graph_module, target):
725725
)
726726

727727
def test_edge_dialect_non_core_aten_ops(self):
728-
class LinalgNorm(torch.nn.Module):
728+
class LinalgRank(torch.nn.Module):
729729
def __init__(self):
730730
super().__init__()
731731

732732
def forward(self, x: torch.Tensor) -> torch.Tensor:
733-
return torch.linalg.norm(x)
733+
return torch.linalg.matrix_rank(x)
734734

735735
from torch._export.verifier import SpecViolationError
736736

737-
input = torch.arange(9, dtype=torch.float) - 4
738-
ep = torch.export.export(LinalgNorm(), (input,), strict=True)
737+
input = torch.ones((9, 9, 9), dtype=torch.float)
738+
ep = torch.export.export(LinalgRank(), (input,), strict=True)
739739

740740
# aten::linalg_norm is not a core op, so it should error out
741741
with self.assertRaises(SpecViolationError):
@@ -748,9 +748,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
748748
ep,
749749
compile_config=EdgeCompileConfig(
750750
_check_ir_validity=True,
751-
_core_aten_ops_exception_list=[
752-
torch.ops.aten.linalg_vector_norm.default
753-
],
751+
_core_aten_ops_exception_list=[torch.ops.aten._linalg_svd.default],
754752
),
755753
)
756754
except SpecViolationError:

exir/tracer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,8 +631,18 @@ def _default_decomposition_table(
631631
]
632632
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
633633
return get_decompositions(decomp_opset)
634+
635+
decomps = default_decompositions()
636+
# Add edge specific decompositions
637+
additional_decomp_ops = [
638+
# TODO: Eventually this op should be added to the core decompo table, and will not
639+
# need to be added here.
640+
torch.ops.aten.linalg_vector_norm.default,
641+
]
642+
additional_decomps = get_decompositions(additional_decomp_ops)
643+
decomps.update(additional_decomps)
634644
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
635-
return default_decompositions()
645+
return decomps
636646

637647

638648
def dynamo_trace(

0 commit comments

Comments
 (0)