Skip to content

Commit

Permalink
Turn on python dispatcher for EdgeOpArgValidator (#3809)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3809

Possibly fixes #3659

We need to enable the python dispatcher so that expand_copy and view_copy will go through the correct meta kernels

Reviewed By: larryliu0820

Differential Revision: D58091304

fbshipit-source-id: f8907ee130720b01c629d55f222eb5a7e63a34bd
(cherry picked from commit ab6f177)
  • Loading branch information
angelayi committed Jun 3, 2024
1 parent 50d1da2 commit fd41791
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
36 changes: 34 additions & 2 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, Dict

import torch
from executorch.exir import ExecutorchBackendConfig
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.error import ExportError
Expand All @@ -26,7 +26,7 @@
from executorch.extension.pybindings.portable_lib import (
_load_for_executorch_from_buffer,
)
from torch.export import export, ExportedProgram
from torch.export import Dim, export, ExportedProgram

from torch.library import impl, Library

Expand Down Expand Up @@ -225,6 +225,38 @@ def test_edge_manager_transform(self):
original_res, # x * y + x
)

def test_issue_3659(self):

class Mul(torch.nn.Module):
def __init__(self):
super(Mul, self).__init__()

def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.matmul(x, y)

def get_eager_model(self) -> torch.nn.Module:
return self

def get_example_inputs(self):
return (torch.randn(1, 3, 10), torch.randn(1, 10, 3))

def get_dynamic_shapes(self):
dim1_x = Dim("Dot_dim1_x", min=2, max=100)
dim2_x = Dim("Dot_dim2_x", min=2, max=100)
return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}}

model = Mul()
ep = torch.export.export(
model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
)

to_edge(
ep,
compile_config=EdgeCompileConfig(
_check_ir_validity=True,
),
)

def test_transform_dict_api(self):
edge_manager = to_edge(get_exported_programs(), get_config_methods())

Expand Down
4 changes: 3 additions & 1 deletion exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EdgeOpArgValidator,
RunHigherOrderOperatorError,
)
from torch._dispatch.python import enable_python_dispatcher

from torch._export.verifier import SpecViolationError, Verifier
from torch._ops import OpOverload
Expand Down Expand Up @@ -119,7 +120,8 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
validator = EdgeOpArgValidator(gm)
inputs = _get_inputs(gm)
try:
validator.run(*inputs)
with enable_python_dispatcher():
validator.run(*inputs)
except RunHigherOrderOperatorError:
# NB: ignore higher order operator in the graph.
# If we lower a graph module to delegate and then compose it with some other graph module, retrace it,
Expand Down

0 comments on commit fd41791

Please sign in to comment.