Closed
Description
I have attempted to lower a toy module with dynamic input shapes; however, I am running into some errors.
I have created the following module:
class Mul(torch.nn.Module):
def __init__(self):
super(Mul, self).__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.matmul(x, y.transpose(-2, -1))
def get_eager_model(self) -> torch.nn.Module:
return self
def get_example_inputs(self):
return (torch.randn(1, 3, 10), torch.randn(1, 3, 10))
def get_dynamic_shapes(self):
dim1_x = Dim("Dot_dim1_x", min=MIN_DIM, max=MAX_DIM)
dim1_y = Dim("Dot_dim1_y", min=MIN_DIM, max=MAX_DIM)
return {"x": {1: dim1_x, 2: dim1_y}, "y": {1: dim1_x, 2: dim1_y}}
I got the following error when lowering with `to_edge()`:
raise InternalError(str(e)) from e
executorch.exir.error.InternalError: Multiple dispatch failed for 'torch.ops.aten.size'; all __torch_dispatch__ handlers returned NotImplemented:
- tensor subclass <class 'torch._subclasses.fake_tensor.FakeTensor'>
For more information, try re-running with TORCH_LOGS=not_implemented
While executing %aten_permute_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%arg1_1, [0, 2, 1]), kwargs = {})
Original traceback:
line 45, in forward
return torch.matmul(x, y.transpose(-2, -1))
I then tried the same thing without `transpose` (judging from the error message, I assume support for it has not been implemented yet), using the following module:
class Mul(torch.nn.Module):
def __init__(self):
super(Mul, self).__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.matmul(x, y)
def get_eager_model(self) -> torch.nn.Module:
return self
def get_example_inputs(self):
return (torch.randn(1, 3, 10), torch.randn(1, 10, 3))
def get_dynamic_shapes(self):
dim1_x = Dim("Dot_dim1_x", min=MIN_DIM, max=MAX_DIM)
dim2_x = Dim("Dot_dim2_x", min=MIN_DIM, max=MAX_DIM)
return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}}
However, I got this error:
raise InternalError(str(e)) from e
executorch.exir.error.InternalError: /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:2291: SymIntArrayRef expected to contain only concrete integers
While executing %aten_expand_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.expand_copy.default](args = (%arg0_1, [1, %sym_size, %sym_size_1]), kwargs = {})
Original traceback:
line 45, in forward
return torch.matmul(x, y)
Here is my lowering code:
def _to_core_aten(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    verbose: bool = True,
) -> ExportedProgram:
    """Export ``model`` into an ``ExportedProgram`` via ``export``.

    Args:
        model: the module (or ``fx.GraphModule``) to export.
        example_inputs: positional example inputs used for tracing.
        dynamic_shapes: optional dynamic-shape spec forwarded unchanged to
            ``export``.
        verbose: when true, log the exported graph at INFO level.

    Returns:
        The program returned by ``export``.

    Raises:
        ValueError: if ``model`` is not a ``torch.nn.Module``.
    """
    # post autograd export. eventually this will become .to_core_aten
    # fx.GraphModule is a subclass of nn.Module, so a single check covers both.
    if not isinstance(model, torch.nn.Module):
        raise ValueError(
            "Expected passed in model to be an instance of torch.nn.Module "
            f"or fx.GraphModule, got {type(model)}"
        )
    core_aten_ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes)
    if verbose:
        # Lazy %-formatting: the graph is only stringified if INFO is enabled.
        logging.info("Core ATen graph:\n%s", core_aten_ep.graph)
    return core_aten_ep
def _core_aten_to_edge(
    core_aten_exir_ep: ExportedProgram,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=None,
    verbose: bool = True,
) -> EdgeProgramManager:
    """Lower an exported program to the Edge dialect with ``to_edge``.

    Args:
        core_aten_exir_ep: the exported program to lower.
        edge_constant_methods: optional mapping forwarded to ``to_edge`` as
            ``constant_methods``.
        edge_compile_config: compile config forwarded to ``to_edge``. When
            ``None``, a config with ``_check_ir_validity=False`` is used.
        verbose: when true, log the resulting Edge graph at INFO level.

    Returns:
        The ``EdgeProgramManager`` produced by ``to_edge``.
    """
    if edge_compile_config is None:
        edge_compile_config = exir.EdgeCompileConfig(
            _check_ir_validity=False,  # quant ops currently break ir verification
        )
    edge_manager: EdgeProgramManager = to_edge(
        core_aten_exir_ep,
        constant_methods=edge_constant_methods,
        compile_config=edge_compile_config,
    )
    if verbose:
        # Lazy %-formatting: the graph is only stringified if INFO is enabled.
        logging.info("Exported graph:\n%s", edge_manager.exported_program().graph)
    return edge_manager
def export_to_edge(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=_EDGE_COMPILE_CONFIG,
    verbose=True,
) -> EdgeProgramManager:
    """Export ``model`` and lower the result to the Edge dialect in one call.

    First runs ``_to_core_aten`` with ``example_inputs`` and
    ``dynamic_shapes``. Then passes the exported program, together with
    ``edge_constant_methods`` and ``edge_compile_config``, to
    ``_core_aten_to_edge``. ``verbose`` is forwarded to both steps.
    """
    exported_program = _to_core_aten(
        model,
        example_inputs,
        dynamic_shapes,
        verbose=verbose,
    )
    return _core_aten_to_edge(
        exported_program,
        edge_constant_methods=edge_constant_methods,
        edge_compile_config=edge_compile_config,
        verbose=verbose,
    )
model = model.eval()
# NOTE(review): the graph captured here is traced again inside export_to_edge
# (it calls export with the same dynamic_shapes) — confirm this double
# capture is intended.
model = torch._export.capture_pre_autograd_graph(
    model, example_inputs, dynamic_shapes=dynamic_shapes
)
edge = export_to_edge(
    model,
    example_inputs,
    dynamic_shapes=dynamic_shapes,
    edge_compile_config=EdgeCompileConfig(
        # IR validity checking is turned off only when quantizing.
        _check_ir_validity=not args.quantize,
    ),
)
Lowering the same module with shapes (1, 3) and (3, 1) works, but it seems to break once I add an extra dimension. Any advice?