Skip to content

Errors when lowering to edge. #3659

Closed
Closed
@ismaeelbashir03

Description

@ismaeelbashir03

I attempted to lower a toy module with dynamic inputs; however, I am running into some errors.

I have created the following module:

class Mul(torch.nn.Module):
    def __init__(self):
        super(Mul, self).__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return torch.matmul(x, y.transpose(-2, -1))
    
    def get_eager_model(self) -> torch.nn.Module:
        return self

    def get_example_inputs(self):
        return (torch.randn(1, 3, 10), torch.randn(1, 3, 10))
    
    def get_dynamic_shapes(self):
        dim1_x = Dim("Dot_dim1_x", min=MIN_DIM, max=MAX_DIM)
        dim1_y = Dim("Dot_dim1_y", min=MIN_DIM, max=MAX_DIM)
        return {"x": {1: dim1_x, 2: dim1_y}, "y": {1: dim1_x, 2: dim1_y}}

I got the following error when lowering with to_edge():


raise InternalError(str(e)) from e
executorch.exir.error.InternalError: Multiple dispatch failed for 'torch.ops.aten.size'; all __torch_dispatch__ handlers returned NotImplemented:

  - tensor subclass <class 'torch._subclasses.fake_tensor.FakeTensor'>

For more information, try re-running with TORCH_LOGS=not_implemented

While executing %aten_permute_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%arg1_1, [0, 2, 1]), kwargs = {})
Original traceback:
line 45, in forward
    return torch.matmul(x, y.transpose(-2, -1))

Judging from the error message, I assumed that support for transpose has not been implemented yet, so I then tried the same thing without transpose, using the following module:

class Mul(torch.nn.Module):
    def __init__(self):
        super(Mul, self).__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return torch.matmul(x, y)
    
    def get_eager_model(self) -> torch.nn.Module:
        return self

    def get_example_inputs(self):
        return (torch.randn(1, 3, 10), torch.randn(1, 10, 3))
    
    def get_dynamic_shapes(self):
        dim1_x = Dim("Dot_dim1_x", min=MIN_DIM, max=MAX_DIM)
        dim2_x = Dim("Dot_dim2_x", min=MIN_DIM, max=MAX_DIM)
        return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}}

However, I got this error:

raise InternalError(str(e)) from e
executorch.exir.error.InternalError: /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:2291: SymIntArrayRef expected to contain only concrete integers

While executing %aten_expand_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.expand_copy.default](args = (%arg0_1, [1, %sym_size, %sym_size_1]), kwargs = {})
Original traceback:
line 45, in forward
    return torch.matmul(x, y)

Here is my lowering code:

def _to_core_aten(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    verbose=True,
) -> ExportedProgram:
    # post autograd export. eventually this will become .to_core_aten
    if not isinstance(model, torch.fx.GraphModule) and not isinstance(
        model, torch.nn.Module
    ):
        raise ValueError(
            f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
        )
    core_aten_ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes)
    if verbose:
        logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
    return core_aten_ep


def _core_aten_to_edge(
    core_aten_exir_ep: ExportedProgram,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=None,
    verbose=True,
) -> EdgeProgramManager:
    """Lower an exported program with ``to_edge`` and return the edge manager.

    Args:
        core_aten_exir_ep: The exported program to lower.
        edge_constant_methods: Forwarded to ``to_edge`` as ``constant_methods``.
        edge_compile_config: Forwarded to ``to_edge`` as ``compile_config``.
            When falsy, an ``exir.EdgeCompileConfig`` with
            ``_check_ir_validity=False`` is used instead.
        verbose: When true, the lowered graph is logged at INFO level.

    Returns:
        The manager returned by ``to_edge``.
    """
    # Fall back to a config with IR validation disabled if none was supplied.
    compile_config = edge_compile_config or exir.EdgeCompileConfig(
        _check_ir_validity=False,  # quant ops currently break ir verification
    )
    manager: EdgeProgramManager = to_edge(
        core_aten_exir_ep,
        constant_methods=edge_constant_methods,
        compile_config=compile_config,
    )
    if not verbose:
        return manager
    logging.info(f"Exported graph:\n{manager.exported_program().graph}")
    return manager


def export_to_edge(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=_EDGE_COMPILE_CONFIG,
    verbose=True,
) -> EdgeProgramManager:
    """Export ``model`` and lower the result to the edge dialect in one call.

    This chains ``_to_core_aten`` and ``_core_aten_to_edge``, passing
    ``verbose`` to both.

    Args:
        model: The module to export.
        example_inputs: Example inputs used for the export step.
        dynamic_shapes: Dynamic-shape spec passed to the export step.
        edge_constant_methods: Constant methods passed to the lowering step.
        edge_compile_config: Compile config passed to the lowering step;
            defaults to the module-level ``_EDGE_COMPILE_CONFIG``.
        verbose: Whether each step logs the graph it produces.

    Returns:
        The edge program manager produced by the lowering step.
    """
    # Step 1: export the model.
    exported_program = _to_core_aten(
        model,
        example_inputs,
        dynamic_shapes,
        verbose=verbose,
    )
    # Step 2: lower the exported program to edge.
    edge_manager = _core_aten_to_edge(
        exported_program,
        edge_constant_methods,
        edge_compile_config,
        verbose=verbose,
    )
    return edge_manager

# Call eval() on the model, then capture its pre-autograd graph.
# NOTE(review): capture_pre_autograd_graph lives under the private
# torch._export namespace — confirm it exists in the torch version in use.
model = model.eval()
model = torch._export.capture_pre_autograd_graph(
    model,
    example_inputs,
    dynamic_shapes=dynamic_shapes,
)

# Lower the captured graph; IR validity checking is on unless quantizing.
edge = export_to_edge(
    model,
    example_inputs,
    dynamic_shapes=dynamic_shapes,
    edge_compile_config=EdgeCompileConfig(
        _check_ir_validity=not args.quantize,
    ),
)

Lowering the same module with the shapes (1, 3) and (3, 1) works, but it seems to break once I add an extra dimension. Any advice?

Metadata

Metadata

Assignees

No one assigned

    Labels

    module: exir — Issues related to Export IR and the code under exir/
    triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions