
Commit c4a09b0

mcremon-meta authored and facebook-github-bot committed
Create export_edge_to_executorch, call export_to_executorch in testing flow, and call print_ops_info in export_to_executorch (#3863)
Summary:
Pull Request resolved: #3863

The current setup is unbalanced, especially because we would love to (and should) make better use of the `print_ops_info` function. Currently it is not used in the calls most people would want to make (on Bento, for example). For one-liner compilation (e.g. calling `export_to_executorch`), the information should be printed.

This diff refactors the APIs in both `__init__.py` and `utils.py` so that the breakdown makes more sense. Arguably it should be a stack of diffs, but it mostly all goes hand in hand IMO.

Main changes:
- create an `export_edge_to_executorch` API, which takes in an `EdgeProgramManager`. This is useful because we want to keep the edge graph module around to pass it to `print_ops_count`
- call `print_ops_info` in `export_to_executorch`
- call `export_to_executorch` in `run_and_verify`, using the exported module
- introduce a `model_is_quantized()` util to call the right API when trying to put models in eval mode. The check on the `GraphModule` type alone is not robust enough, since other models could be `GraphModule`s but not be quantized. If that's the case, we assert that they have been exported already, which makes the `eval()` requirement moot.

Differential Revision: D58101124
1 parent 6554fa5 commit c4a09b0
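
As a rough illustration of the breakdown described in the summary, here is a minimal sketch of how the two entry points could fit together. It is not the actual implementation: `export_edge_to_executorch`, `export_to_executorch`, `EdgeProgramManager`, and `print_ops_info` come from this change, but the `export`/`to_edge` plumbing and the exact `print_ops_info` call site shown here are assumptions.

```python
# Minimal sketch, not the actual Cadence implementation; real signatures may differ.
import torch
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge
from torch.export import export


def export_edge_to_executorch(edge_prog: EdgeProgramManager) -> ExecutorchProgramManager:
    # Takes an EdgeProgramManager so callers can keep the edge graph module
    # around (e.g. for op-count reporting) before lowering to ExecuTorch.
    return edge_prog.to_executorch()


def export_to_executorch(model: torch.nn.Module, inputs: tuple) -> ExecutorchProgramManager:
    # One-liner compilation: export -> edge dialect -> report ops -> ExecuTorch.
    edge_prog = to_edge(export(model, inputs))
    # print_ops_info(edge_prog.exported_program().graph_module)  # assumed call site
    return export_edge_to_executorch(edge_prog)
```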

File tree

3 files changed (+37, -6 lines)


backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -29,6 +29,7 @@ python_library(
     ],
     deps = [
         ":passes",
+        ":utils",
         "//caffe2:torch",
         "//executorch/exir:lib",
     ],
```

backends/cadence/aot/compiler.py

Lines changed: 15 additions & 6 deletions
```diff
@@ -16,7 +16,9 @@
     ReplaceScalarTensorWithFullPass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
+from executorch.backends.cadence.aot.utils import model_is_quantized
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
+from torch.ao.quantization.pt2e.export_utils import model_is_exported
 
 from torch.export import export
 from torch.export.exported_program import ExportedProgram
@@ -29,14 +31,21 @@ def export_program(
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
-    # If the model is already a GraphModule (most likely from quantization), call the
-    # suggested torch.ao.quantization API instead, which only does dropout and batchnorm.
-    if isinstance(model, torch.fx.GraphModule):
-        torch.ao.quantization.move_exported_model_to_eval(model)
-    else:
-        # We don't support training mode. Make it eval
+    # We don't support training mode. Make the model inference mode by
+    # calling model.eval() or an equivalent call for quantized models.
+    # GraphModules cannot call eval(), so we skip them.
+    if not isinstance(model, torch.fx.GraphModule):
         if hasattr(model, "eval"):
             model.eval()
+    else:
+        # If the model is quantized, call the suggested torch.ao.quantization API
+        # which only does dropout and batchnorm.
+        if model_is_quantized(model):
+            torch.ao.quantization.move_exported_model_to_eval(model)
+        else:
+            # If we get a GraphModule which is _not_ quantized, then it should already
+            # have been exported.
+            assert model_is_exported(model), "model should be from an ExportedProgram"
 
     # Prevent mkldnn decompositions
     torch._C._set_mkldnn_enabled(False)
```
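
As context for the new branch above, a hedged sketch of how each kind of model typically reaches `export_program`. The toy module and `XNNPACKQuantizer` are stand-ins (the Cadence flow uses its own quantizer), and the PT2E capture step shown may differ across PyTorch versions.

```python
# Illustration only: the three model kinds export_program now has to handle.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


inputs = (torch.randn(1, 8),)

# 1) Plain nn.Module: export_program can simply call .eval() on it.
eager_model = Toy()

# 2) Quantized GraphModule (from prepare/convert): .eval() is not supported, so
#    torch.ao.quantization.move_exported_model_to_eval() is used instead.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(export(Toy(), inputs).module(), quantizer)
prepared(*inputs)  # calibrate
quantized_gm = convert_pt2e(prepared)

# 3) Exported-but-unquantized GraphModule: neither call applies; the new code
#    only asserts that model_is_exported(...) holds.
exported_gm = export(Toy(), inputs).module()
```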

backends/cadence/aot/utils.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -14,6 +14,27 @@
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from tabulate import tabulate
 
+quant_ops = {
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+}
+
+
+# Check if the model is quantized, by looking at the graph and finding quant/dequant ops
+def model_is_quantized(model: torch.nn.Module) -> bool:
+    # Quantized models have to be GraphModules already, from prepare/convert calls.
+    # Return false if the model is not a GraphModule.
+    if not isinstance(model, torch.fx.GraphModule):
+        return False
+
+    # Walk through the graph and look for quant/dequant ops
+    for op in quant_ops:
+        if model.graph.find_nodes(op="call_function", target=op):
+            return True
+    return False
+
 
 # Get the output size of a 1D convolution given the input size and parameters
 def get_conv1d_output_size(
```
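
A short usage sketch of the new helper; the toy module and capture calls below are illustrative only, not part of this diff.

```python
# Usage sketch for model_is_quantized.
import torch
from executorch.backends.cadence.aot.utils import model_is_quantized
from torch.export import export


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


inputs = (torch.randn(4),)

# A plain nn.Module is not a GraphModule, so the helper returns False.
assert not model_is_quantized(Toy())

# An exported-but-unquantized GraphModule contains no quantized_decomposed ops,
# so it also returns False.
assert not model_is_quantized(export(Toy(), inputs).module())

# Only a GraphModule containing quantize/dequantize call_function nodes
# (e.g. one produced by convert_pt2e) returns True.
```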
