
Commit 9a0b51c

Make the quantized path the main testing path, and introduce a nop quantizer for fp32 cases
Differential Revision: D67561806
Pull Request resolved: #7915
1 parent: 7600f18

File tree

backends/cadence/aot/compiler.py
backends/cadence/aot/quantizer/quantizer.py
backends/cadence/aot/utils.py

3 files changed: +15 −37 lines

backends/cadence/aot/compiler.py

Lines changed: 3 additions & 22 deletions
@@ -24,7 +24,6 @@
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
     MemoryConfig,
-    model_is_quantized,
 )
 from executorch.devtools import generate_etrecord
 from executorch.exir import (
@@ -38,7 +37,6 @@
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from torch._inductor.decomposition import remove_decompositions
-from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

 from torch.export import export
@@ -158,26 +156,10 @@ def export_program(
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

-    # We don't support training mode. Make the model inference mode by
-    # calling model.eval() or an equivalent call for quantized models.
-    # GraphModules cannot call eval(), so we skip them.
-    if not isinstance(model, torch.fx.GraphModule):
-        if hasattr(model, "eval"):
-            model.eval()
-    else:
-        # If the model is quantized, call the suggested torch.ao.quantization API
-        # which only does dropout and batchnorm.
-        if model_is_quantized(model):
-            torch.ao.quantization.move_exported_model_to_eval(model)
-        else:
-            # If we get a GraphModule which is _not_ quantized, then it should already
-            # have been exported.
-            assert model_is_exported(model), "model should be from an ExportedProgram"
-
     # Prevent mkldnn decompositions
     torch._C._set_mkldnn_enabled(False)

-    # else: capture the model and return it.
+    # Export the model and return it.
     expo_program = export(model, inputs, strict=True)

     if dump_graphs:
@@ -206,8 +188,8 @@ def export_to_edge(
             _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
             _core_aten_ops_exception_list=[
+                torch.ops.aten._native_batch_norm_legit_functional.default,
                 torch.ops.aten.linear.default,
-                torch.ops.aten.native_batch_norm.default,
                 torch.ops.aten.linalg_vector_norm.default,
                 torch.ops.aten.unfold.default,
                 torch.ops.aten.angle.default,
@@ -226,10 +208,9 @@ def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
-    output_dir: Optional[str] = None,
     opt_level: int = 1,
 ) -> EdgeProgramManager:
-    edge_prog_manager = export_to_edge(model, inputs)
+    edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
     cadence_passes = get_cadence_passes(opt_level)

     # Run a couple required passes for quant/dequant ops
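
With this change, dump_graphs is now forwarded from export_to_cadence to export_to_edge and the output_dir parameter is removed. A minimal sketch of how the updated entry point might be called after this commit (the toy module below is a hypothetical example, not part of the diff):

import torch

from executorch.backends.cadence.aot.compiler import export_to_cadence


# Hypothetical toy model, used only to illustrate the call.
class SmallLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = SmallLinear()
inputs = (torch.randn(1, 8),)

# dump_graphs now propagates to both the export and edge stages.
edge_prog_manager = export_to_cadence(model, inputs, dump_graphs=True, opt_level=1)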

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 12 additions & 0 deletions
@@ -183,3 +183,15 @@ def __init__(self, qconfig: Optional[QuantizationConfig] = None) -> None:
             qconfig = _default_qconfig
         quantizers = get_cadence_default_quantizer_list_with_config(qconfig)
         super().__init__(quantizers)
+
+
+# Nop quantizer, used to run fp32 cases.
+# Calls an empty list of quantizers (no quantization). Note
+# that we do not strictly need that class since we could call
+# CadenceQuantizer([]), but this is more explicit and
+# does not require knowledge of the internals of the base class.
+class CadenceNopQuantizer(CadenceQuantizer):
+    def __init__(
+        self,
+    ) -> None:
+        super().__init__([])
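
As a usage illustration (not part of this commit), an fp32 test can now run through the same prepare/convert flow as the quantized tests by passing CadenceNopQuantizer, which annotates nothing and therefore inserts no quant/dequant ops. The capture step via torch.export.export_for_training is an assumption about the surrounding test harness:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export_for_training

from executorch.backends.cadence.aot.quantizer.quantizer import CadenceNopQuantizer

model = torch.nn.Linear(8, 4)
inputs = (torch.randn(1, 8),)

# An empty quantizer list means annotate() marks nothing, so prepare/convert
# leave the graph in fp32 while the test still exercises the quantized path.
quantizer = CadenceNopQuantizer()
captured = export_for_training(model, inputs).module()
prepared = prepare_pt2e(captured, quantizer)
prepared(*inputs)  # "calibration" is a no-op here, but keeps the flow identical
converted = convert_pt2e(prepared)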

backends/cadence/aot/utils.py

Lines changed: 0 additions & 15 deletions
@@ -20,24 +20,9 @@
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from tabulate import tabulate

-from torch.ao.quantization.quantize_pt2e import _QUANT_OPS as quant_ops
 from torch.utils._pytree import tree_flatten


-# Check if the model is quantized, by looking at the graph and finding quant/dequant ops
-def model_is_quantized(model: torch.nn.Module) -> bool:
-    # Quantized models have to be GraphModules already, from prepare/convert calls.
-    # Return false if the model is not a GraphModule.
-    if not isinstance(model, torch.fx.GraphModule):
-        return False
-
-    # Walk through the graph and look for quant/dequant ops
-    for op in quant_ops:
-        if model.graph.find_nodes(op="call_function", target=op):
-            return True
-    return False
-
-
 # Get the output size of a 1D convolution given the input size and parameters
 def get_conv1d_output_size(
     in_size: torch.Size,
