Fix quantized_matmul with 4D inputs (#4335)
Summary:
Pull Request resolved: #4335

MobileBERT has a matmul with 4D inputs (`[1, 4, 8, 32]` by `[1, 4, 32, 8]`), which errors out ahead of time (AoT) in the meta kernel.

This diff fixes the meta kernel to handle cases where the leading dimensions are more than one (the kernel itself can already handle it!).
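
As a rough illustration of the shape rule involved, the sketch below computes the output shape of a matmul whose operands carry any number of leading dimensions. The helper name `matmul_out_shape` is hypothetical and this is not the actual meta kernel code; it assumes both operands have identical leading dimensions (no broadcasting) and ignores any transpose or quantization parameters the real op may take.

```python
def matmul_out_shape(x_shape, y_shape):
    """Output shape of X @ Y, where X is [..., M, K] and Y is [..., K, N]."""
    if len(x_shape) < 2 or len(y_shape) < 2:
        raise ValueError("both operands need at least 2 dimensions")

    # Split each shape into leading (batch) dims and the trailing matrix dims.
    *x_lead, m, k = x_shape
    *y_lead, k2, n = y_shape

    if k != k2:
        raise ValueError(f"inner dimensions do not match: {k} vs {k2}")
    # Assumption for this sketch: leading dims must match exactly.
    if x_lead != y_lead:
        raise ValueError(f"leading dimensions do not match: {x_lead} vs {y_lead}")

    # Leading dims pass through unchanged; only the last two dims are contracted.
    return [*x_lead, m, n]


# The MobileBERT case from the summary above.
print(matmul_out_shape([1, 4, 8, 32], [1, 4, 32, 8]))  # [1, 4, 8, 8]
```

The same helper covers the plain 2D case (`[8, 32]` by `[32, 8]` gives `[8, 8]`), since the leading-dimension list is then simply empty.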

Also move the exported graph dump to `export_program`, where it belongs. This avoids printing the graph twice in some cases.
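
The pattern can be sketched with stand-in functions: the function that produces the graph owns the dump, and callers only forward the flag. The names `export_program_sketch` and `export_to_edge_sketch` and the fake graph strings are hypothetical, used so the example runs without torch or executorch.

```python
import logging


def export_program_sketch(model_name, dump_graphs=False):
    # Stand-in for the real export step: build a fake "graph" string.
    graph = f"graph<{model_name}>"
    if dump_graphs:
        # The dump lives here, next to the export it describes.
        logging.info("Exported graph: %s", graph)
    return graph


def export_to_edge_sketch(model_name, dump_graphs=False):
    # The caller forwards the flag and does not log the graph itself,
    # so the graph is dumped at most once per export.
    graph = export_program_sketch(model_name, dump_graphs=dump_graphs)
    return f"edge({graph})"
```

With this layout, calling either function with `dump_graphs=True` emits the graph exactly once, regardless of which entry point is used.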

Note: this diff needs a GH approval!

Reviewed By: dulinriley, zonglinpengmeta

Differential Revision: D60050087

fbshipit-source-id: de09ed2fb9c5cdf729cc020119bf090d0f0c70c4
mcremon-meta authored and facebook-github-bot committed Jul 22, 2024
1 parent 844a69f commit f0364e8
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions backends/cadence/aot/compiler.py
@@ -76,6 +76,7 @@ def quantize_pt2(
 def export_program(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
+    dump_graphs: bool = False,
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

@@ -99,7 +100,13 @@ def export_program(
     torch._C._set_mkldnn_enabled(False)

     # else: capture the model and return it.
-    return export(model, inputs)
+    expo_program = export(model, inputs)
+
+    if dump_graphs:
+        logging.info("Exported graph:")
+        expo_program.graph_module.graph.print_tabular()
+
+    return expo_program


 # Export the model and lower it to an EdgeProgramManager (in edge IR).
@@ -111,11 +118,7 @@ def export_to_edge(
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs)
-
-    if dump_graphs:
-        logging.info("Exported graph:")
-        expo_program.graph_module.graph.print_tabular()
+    expo_program = export_program(model, inputs, dump_graphs=dump_graphs)

     # Call to_edge to convert the graph to edge IR.
     # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)