Create export_edge_to_executorch, call export_to_executorch in testing flow, and call print_ops_info in export_to_executorch (#3863)

Summary:
Pull Request resolved: #3863

The current setup is unbalanced in the way we call the different APIs, especially because we would love to (and should) make better use of the `print_ops_info` function. Currently, it is not called in the flows most people would actually use, for example on Bento.

For one-liner compilation (e.g. calling `export_to_executorch`), that op information should be printed. `export_to_executorch` should also be called in the testing flow.
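
As a rough sketch of that one-liner flow (the import path, the `export_to_executorch` signature, and the toy model below are assumptions for illustration, not taken from this diff):

```python
import torch

# Assumed import path; the commit only says the API lives under backends/cadence/aot.
from executorch.backends.cadence.aot.compiler import export_to_executorch


class SmallModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = SmallModel()
example_inputs = (torch.randn(1, 16),)

# One call lowers the model to ExecuTorch and, after this change, also prints
# the op-level breakdown via print_ops_info.
exec_prog = export_to_executorch(model, example_inputs)
```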

This diff refactors both the APIs in `__init__.py` and `utils.py`, so that the breakdown makes more sense. Arguably it should be a stack of diffs, but it mostly all goes hand in hand IMO so I kept it as one.

Main changes:
- create an `export_edge_to_executorch` API, which takes an `EdgeProgramManager`. This is useful because we want to keep the edge graph module around to pass it to `print_ops_count`, and we can now reuse it in `export_to_executorch` (see next point and the sketch after this list)
- call `print_ops_info` in `export_to_executorch`, now that the edge graph is exposed there
- call `export_to_executorch` in `run_and_verify`, using the exported module. This required changing the checks for `eval()` mode; see the next point
- introduce a `model_is_quantized()` util to call the right API when putting models in eval mode. The check on the `GraphModule` type alone is not robust enough, since a model can be a `GraphModule` without being quantized. If that's the case, we assert that it has already been exported, which makes the `eval()` requirement moot.
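
A minimal sketch of how the refactored pieces could fit together (the signatures and the commented-out `print_ops_info` call are assumptions based on the description above, not the actual implementation):

```python
from typing import Any, Tuple

import torch
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge
from torch.export import export


def export_edge_to_executorch(edge_prog: EdgeProgramManager) -> ExecutorchProgramManager:
    # Takes an EdgeProgramManager so callers can keep the edge graph module
    # around (e.g. for op-count reporting) before lowering the rest of the way.
    return edge_prog.to_executorch()


def export_to_executorch(
    model: torch.nn.Module, inputs: Tuple[Any, ...]
) -> ExecutorchProgramManager:
    edge_prog = to_edge(export(model, inputs))
    # The edge graph is available here, so the one-liner flow can print op info.
    # The real print_ops_info signature may differ:
    # print_ops_info(edge_prog.exported_program().graph_module, ...)
    return export_edge_to_executorch(edge_prog)
```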

Reviewed By: dulinriley, zonglinpengmeta

Differential Revision: D58101124

fbshipit-source-id: 9822411c2a832d539c96ab61aff99586da206d01
mcremon-meta authored and facebook-github-bot committed Jun 7, 2024
1 parent c18cea7 commit d928066
Showing 3 changed files with 32 additions and 6 deletions.
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
@@ -29,6 +29,7 @@ python_library(
     ],
     deps = [
         ":passes",
+        ":utils",
         "//caffe2:torch",
         "//executorch/exir:lib",
     ],
21 changes: 15 additions & 6 deletions backends/cadence/aot/compiler.py
@@ -16,7 +16,9 @@
     ReplaceScalarTensorWithFullPass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
+from executorch.backends.cadence.aot.utils import model_is_quantized
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
+from torch.ao.quantization.pt2e.export_utils import model_is_exported

 from torch.export import export
 from torch.export.exported_program import ExportedProgram
@@ -29,14 +31,21 @@ def export_program(
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

-    # If the model is already a GraphModule (most likely from quantization), call the
-    # suggested torch.ao.quantization API instead, which only does dropout and batchnorm.
-    if isinstance(model, torch.fx.GraphModule):
-        torch.ao.quantization.move_exported_model_to_eval(model)
-    else:
-        # We don't support training mode. Make it eval
+    # We don't support training mode. Make the model inference mode by
+    # calling model.eval() or an equivalent call for quantized models.
+    # GraphModules cannot call eval(), so we skip them.
+    if not isinstance(model, torch.fx.GraphModule):
         if hasattr(model, "eval"):
             model.eval()
+    else:
+        # If the model is quantized, call the suggested torch.ao.quantization API
+        # which only does dropout and batchnorm.
+        if model_is_quantized(model):
+            torch.ao.quantization.move_exported_model_to_eval(model)
+        else:
+            # If we get a GraphModule which is _not_ quantized, then it should already
+            # have been exported.
+            assert model_is_exported(model), "model should be from an ExportedProgram"

     # Prevent mkldnn decompositions
     torch._C._set_mkldnn_enabled(False)
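
Not part of the diff — a hedged illustration of which branch each kind of model takes in the new `export_program` logic, using the standard PT2E flow (the XNNPACK quantizer stands in for the Cadence quantizer purely for illustration):

```python
import torch
from torch._export import capture_pre_autograd_graph  # PT2E capture API current at the time
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export

model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU())
example_inputs = (torch.randn(1, 16),)

# 1) Plain nn.Module: not a GraphModule, so export_program calls model.eval().
eager_model = model

# 2) Quantized GraphModule: prepare/convert leave quant/dequant nodes in the graph,
#    so model_is_quantized() is True and move_exported_model_to_eval() is used.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
captured = capture_pre_autograd_graph(model, example_inputs)
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration
quantized_gm = convert_pt2e(prepared)

# 3) GraphModule that is not quantized: it must already come from an exported
#    program (model_is_exported() is asserted), so the eval() requirement is moot.
exported_gm = export(torch.nn.Linear(16, 8), (torch.randn(1, 16),)).module()
```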
16 changes: 16 additions & 0 deletions backends/cadence/aot/utils.py
@@ -14,6 +14,22 @@
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from tabulate import tabulate

+from torch.ao.quantization.quantize_pt2e import _QUANT_OPS as quant_ops
+
+
+# Check if the model is quantized, by looking at the graph and finding quant/dequant ops
+def model_is_quantized(model: torch.nn.Module) -> bool:
+    # Quantized models have to be GraphModules already, from prepare/convert calls.
+    # Return false if the model is not a GraphModule.
+    if not isinstance(model, torch.fx.GraphModule):
+        return False
+
+    # Walk through the graph and look for quant/dequant ops
+    for op in quant_ops:
+        if model.graph.find_nodes(op="call_function", target=op):
+            return True
+    return False
+

 # Get the output size of a 1D convolution given the input size and parameters
 def get_conv1d_output_size(
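
For reference, a small usage sketch of the new helper (the eager model and inputs are illustrative; the import path matches the compiler.py import above):

```python
import torch
from executorch.backends.cadence.aot.utils import model_is_quantized
from torch.export import export

eager = torch.nn.Linear(4, 4)
assert not model_is_quantized(eager)        # plain nn.Module: not a GraphModule

exported_gm = export(eager, (torch.randn(1, 4),)).module()
assert not model_is_quantized(exported_gm)  # GraphModule, but no quant/dequant nodes

# After prepare_pt2e/convert_pt2e (see the compiler.py sketch above), the converted
# GraphModule contains quantize/dequantize call_function nodes, so the helper
# would return True for it.
```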
