@@ -24,7 +24,6 @@
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
     MemoryConfig,
-    model_is_quantized,
 )
 from executorch.devtools import generate_etrecord
 from executorch.exir import (
@@ -38,7 +37,6 @@
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from torch._inductor.decomposition import remove_decompositions
-from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 from torch.export import export
@@ -158,26 +156,10 @@ def export_program(
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
-    # We don't support training mode. Make the model inference mode by
-    # calling model.eval() or an equivalent call for quantized models.
-    # GraphModules cannot call eval(), so we skip them.
-    if not isinstance(model, torch.fx.GraphModule):
-        if hasattr(model, "eval"):
-            model.eval()
-    else:
-        # If the model is quantized, call the suggested torch.ao.quantization API
-        # which only does dropout and batchnorm.
-        if model_is_quantized(model):
-            torch.ao.quantization.move_exported_model_to_eval(model)
-        else:
-            # If we get a GraphModule which is _not_ quantized, then it should already
-            # have been exported.
-            assert model_is_exported(model), "model should be from an ExportedProgram"
-
     # Prevent mkldnn decompositions
     torch._C._set_mkldnn_enabled(False)
 
-    # else: capture the model and return it.
+    # Export the model and return it.
     expo_program = export(model, inputs, strict=True)
 
     if dump_graphs:
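With this hunk, export_program no longer flips models into inference mode itself (which is also why the model_is_quantized and model_is_exported imports are removed above); it exports whatever it is given. A minimal caller-side sketch of the flow this assumes, where SmallModel is a hypothetical stand-in and not part of this PR:

import torch
from torch.export import export


class SmallModel(torch.nn.Module):  # hypothetical stand-in model
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = SmallModel()
model.eval()  # inference mode is now the caller's responsibility
inputs = (torch.randn(2, 4),)
# The same strict export call that export_program now performs unconditionally.
exported = export(model, inputs, strict=True)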
@@ -206,8 +188,8 @@ def export_to_edge(
             _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
             _core_aten_ops_exception_list=[
+                torch.ops.aten._native_batch_norm_legit_functional.default,
                 torch.ops.aten.linear.default,
-                torch.ops.aten.native_batch_norm.default,
                 torch.ops.aten.linalg_vector_norm.default,
                 torch.ops.aten.unfold.default,
                 torch.ops.aten.angle.default,
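The list swap tracks what strict export actually produces: functionalization rewrites aten.native_batch_norm into aten._native_batch_norm_legit_functional, so the functional variant is the op that needs the exemption. A hedged sketch of how such an exception list is consumed; the field names come from the diff above, while the surrounding to_edge call shape is an assumption:

import torch
from executorch.exir import EdgeCompileConfig, to_edge

edge_config = EdgeCompileConfig(
    _skip_dim_order=True,
    _core_aten_ops_exception_list=[
        # The functionalized batch-norm variant emitted by strict export.
        torch.ops.aten._native_batch_norm_legit_functional.default,
        torch.ops.aten.linear.default,
    ],
)
# `exported` would be an ExportedProgram such as the one built earlier.
edge_prog_manager = to_edge(exported, compile_config=edge_config)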
@@ -226,10 +208,9 @@ def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
-    output_dir: Optional[str] = None,
     opt_level: int = 1,
 ) -> EdgeProgramManager:
-    edge_prog_manager = export_to_edge(model, inputs)
+    edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
     cadence_passes = get_cadence_passes(opt_level)
 
     # Run a couple required passes for quant/dequant ops
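export_to_cadence now forwards dump_graphs to export_to_edge instead of silently dropping it, and the unused output_dir parameter is gone. A hypothetical call against the new signature, with model and inputs as stand-ins:

edge_prog_manager = export_to_cadence(
    model,
    inputs,
    dump_graphs=True,  # now actually forwarded to export_to_edge
    opt_level=1,
)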