Skip to content

Commit 9f8f912

Browse files
committed
Integrate NeutronAtenPassManager passes into pipeline
1 parent 0060581 commit 9f8f912

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from typing import List, Optional, Tuple, Union
88

99
import torch
10+
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
11+
NeutronAtenPassManager,
12+
)
1013

1114
from executorch.backends.nxp.quantizer.patterns import (
1215
AddmmPattern,
@@ -202,4 +205,5 @@ def __init__(self):
202205
def transform_for_annotation(
203206
self, model: torch.fx.GraphModule
204207
) -> torch.fx.GraphModule:
205-
return model
208+
pass_runner = NeutronAtenPassManager()
209+
return pass_runner(model).graph_module

backends/nxp/tests/executorch_pipeline.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
from executorch import exir
99
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
1010
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
11-
12-
# TODO (Robert Kalmar) Uncomment when NXP passes are ported to main
13-
# from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass_manager import NXPPyTorchPassManager
1411
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
1512
from executorch.exir import (
1613
EdgeCompileConfig,
@@ -27,7 +24,7 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
2724
quantizer = NeutronQuantizer()
2825

2926
m = prepare_pt2e(model, quantizer)
30-
for _i, data in enumerate(calibration_inputs):
27+
for data in calibration_inputs:
3128
m(*data)
3229
m = convert_pt2e(m)
3330

@@ -48,14 +45,8 @@ def to_quantized_edge_program(
4845
model, example_input, strict=True
4946
)
5047

51-
# TODO(Robert Kalmar) uncoment when NXP passes are ported to main
52-
# Run pre-processing passes of the float32 aten dialect program.
53-
# pass_manager = NXPPyTorchPassManager(exir_program_aten)
54-
# pass_manager.run() # All passes by default.
55-
56-
exir_program_aten_module = exir_program_aten.module()
5748
exir_program_aten__module_quant = _quantize_model(
58-
exir_program_aten_module, calibration_inputs
49+
exir_program_aten.module(), calibration_inputs
5950
)
6051

6152
compile_spec = generate_neutron_compile_spec(

0 commit comments

Comments
 (0)