diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 01de1f3bef..b7636e76a5 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -10,7 +10,7 @@ from typing import Any, Dict import torch -from executorch.exir import ExecutorchBackendConfig +from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.error import ExportError @@ -26,7 +26,7 @@ from executorch.extension.pybindings.portable_lib import ( _load_for_executorch_from_buffer, ) -from torch.export import export, ExportedProgram +from torch.export import Dim, export, ExportedProgram from torch.library import impl, Library @@ -225,6 +225,38 @@ def test_edge_manager_transform(self): original_res, # x * y + x ) + def test_issue_3659(self): + + class Mul(torch.nn.Module): + def __init__(self): + super(Mul, self).__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.matmul(x, y) + + def get_eager_model(self) -> torch.nn.Module: + return self + + def get_example_inputs(self): + return (torch.randn(1, 3, 10), torch.randn(1, 10, 3)) + + def get_dynamic_shapes(self): + dim1_x = Dim("Dot_dim1_x", min=2, max=100) + dim2_x = Dim("Dot_dim2_x", min=2, max=100) + return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}} + + model = Mul() + ep = torch.export.export( + model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes() + ) + + to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=True, + ), + ) + def test_transform_dict_api(self): edge_manager = to_edge(get_exported_programs(), get_config_methods()) diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index d7b1aa5608..fe362824e4 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -17,6 +17,7 @@ EdgeOpArgValidator, RunHigherOrderOperatorError, ) +from torch._dispatch.python import enable_python_dispatcher from torch._export.verifier import SpecViolationError, Verifier from torch._ops import OpOverload @@ -119,7 +120,8 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None: validator = EdgeOpArgValidator(gm) inputs = _get_inputs(gm) try: - validator.run(*inputs) + with enable_python_dispatcher(): + validator.run(*inputs) except RunHigherOrderOperatorError: # NB: ignore higher order operator in the graph. # If we lower a graph module to delegate and then compose it with some other graph module, retrace it,