Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions docs/website/docs/tutorials/exporting_to_executorch.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,20 @@ class MyModule(torch.nn.Module):

aten_dialect = exir.capture(MyModule(), (torch.randn(3, 4),))

print(aten_dialect.exported_program)
print(aten_dialect)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 4], arg1_1: f32[5, 4], arg2_1: f32[5], arg3_1: f32[3, 4]):
add: f32[3, 4] = torch.ops.aten.add.Tensor(arg3_1, arg0_1);
permute: f32[4, 5] = torch.ops.aten.permute_copy.default(arg1_1, [1, 0]);
addmm: f32[3, 5] = torch.ops.aten.addmm.default(arg2_1, add, permute);
clamp: f32[3, 5] = torch.ops.aten.clamp.default(addmm, 0.0, 1.0);
return (clamp,)
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[4, 4]):
# File: /Users/marksaroufim/Dev/zzz/test3.py:10, code: return self.linear(x)
_param_constant0 = self._param_constant0
t: f32[4, 4] = torch.ops.aten.t.default(_param_constant0); _param_constant0 = None
_param_constant1 = self._param_constant1
addmm: f32[4, 4] = torch.ops.aten.addmm.default(_param_constant1, arg0_1, t); _param_constant1 = arg0_1 = t = None
return [addmm]

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=[], user_outputs=[], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
"""
```

Expand All @@ -106,18 +110,22 @@ This lowering will be done through the `to_edge()` API.

```python
aten_dialect = exir.capture(MyModule(), (torch.randn(3, 4),))
edge_dialect = aten_dialect.to_edge()
edge_dialect = aten_dialect.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
Copy link
Member Author

@msaroufim msaroufim Jul 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without this you get this error torch._export.verifier.SpecViolationError: Operator torch._ops.aten.t.default is not Aten Canonical. - should probably figure out how to make this error go away

I could for example get rid of this error by reworking the example to just do vector multiplication but like matmuls are probably more interesting lol https://gist.github.com/msaroufim/629b5c623fade2d5a30bec379f9e08da

Copy link
Contributor

@angelayi angelayi Jul 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error should be fixed by D47346723 (which will be landed before PP) where aten.t will get decomposed to aten.permute which is ATen Canonical. We want to avoid users using the _check_ir_validity flag, but we should probably provide a better error message like "Please file an issue to executorch team, or turn on _check_ir_validity flag to unblock yourself for now"


print(edge_dialect.exported_program)
print(edge_dialect)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 4], arg1_1: f32[5, 4], arg2_1: f32[5], arg3_1: f32[3, 4]):
add: f32[3, 4] = executorch_exir_dialects_edge__ops_aten_add_Tensor(arg3_1, arg0_1);
permute: f32[4, 5] = executorch_exir_dialects_edge__ops_permute_copy_default(arg1_1, [1, 0]);
addmm: f32[3, 5] = executorch_exir_dialects_edge__ops_addmm_default(arg2_1, add, permute);
clamp: f32[3, 5] = executorch_exir_dialects_edge__ops_clamp_default(addmm, 0.0, 1.0);
return (clamp,)
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 3]):
# File: /Users/marksaroufim/Dev/zzz/test3.py:10, code: return self.linear(x)
_param_constant0: f32[3, 3] = self._param_constant0
t_copy_default: f32[3, 3] = torch.ops.aten.t_copy.default(_param_constant0); _param_constant0 = None
_param_constant1: f32[3] = self._param_constant1
addmm_default: f32[3, 3] = torch.ops.aten.addmm.default(_param_constant1, arg0_1, t_copy_default); _param_constant1 = arg0_1 = t_copy_default = None
return [addmm_default]

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=[], user_outputs=[], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
"""
```

Expand Down