Skip to content

Arm backend: Move ReplaceScalarTensorWithFullPass to transforms #8998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
12 changes: 10 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
DecomposeSelectPass,
)
from executorch.backends.arm._passes.decompose_softmax_pass import DecomposeSoftmaxPass
from executorch.backends.arm._passes.decompose_softmax_unstable_pass import (
from executorch.backends.arm._passes.decompose_softmax_pass import ( # type: ignore[import-not-found]
DecomposeSoftmaxPass,
)
from executorch.backends.arm._passes.decompose_softmax_unstable_pass import ( # type: ignore[import-not-found]
DecomposeSoftmaxUnstablePass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
Expand Down Expand Up @@ -85,6 +87,10 @@
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform

from executorch.backends.transforms.replace_scalar_tensor_with_full import ( # type: ignore[import-not-found]
ReplaceScalarTensorWithFullPass,
)

from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
Expand Down Expand Up @@ -143,6 +149,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
return self._transform(exported_program.graph_module)

def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ReplaceScalarTensorWithFullPass())
self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
Expand Down Expand Up @@ -213,4 +220,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeSoftmaxPass())

self.add_pass(ConvertMinMaxPass())
self.add_pass(ReplaceScalarTensorWithFullPass())
return self._transform(graph_module)
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def is_node_supported(
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
torch.ops.aten.scalar_tensor.default,
]

return supported
Expand Down
1 change: 0 additions & 1 deletion backends/arm/test/models/test_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class TestConformer(unittest.TestCase):
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
"torch.ops.aten._assert_scalar.default": 10,
"torch.ops.aten._local_scalar_dense.default": 1,
"torch.ops.aten.scalar_tensor.default": 2,
"torch.ops.higher_order.executorch_call_delegate": 6,
}

Expand Down
137 changes: 137 additions & 0 deletions backends/arm/test/ops/test_scalar_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.xnnpack.test.tester.tester import Quantize
from parameterized import parameterized


# Every case is a tuple of (test_name, scalar input, scalar input type).
float_test_data_suite = [
    ("scalar_tensor_float_1", 3.7, torch.float32),
    # A Python int paired with a floating-point dtype.
    ("scalar_tensor_float_2", 66, torch.float32),
]

# Same tuple layout as above: (test_name, scalar input, scalar input type).
int_test_data_suite = [
    ("scalar_tensor_int32", 33, torch.int32),
    ("scalar_tensor_int8", 8, torch.int8),
    ("scalar_tensor_int16", 16**3, torch.int16),  # 4096
]


class ScalarTensor(torch.nn.Module):
    """Module without inputs whose forward pass builds a tensor from a stored scalar."""

    def __init__(self, scalar, dtype=torch.float32):
        """Remember the scalar value and the dtype to build the tensor with."""
        super().__init__()
        self.scalar, self.dtype = scalar, dtype

    def forward(self):
        """Return ``torch.scalar_tensor`` of the stored scalar in the stored dtype."""
        result = torch.scalar_tensor(self.scalar, dtype=self.dtype)
        return result


class TestScalarTensor(unittest.TestCase):
    """Runs ``torch.scalar_tensor`` modules through the TOSA-0.80 MI and BI test pipelines."""

    def _test_scalar_tensor_tosa_MI_pipeline(
        self, module: torch.nn.Module, expected_output
    ):
        """Export, lower and run ``module`` with the TOSA-0.80+MI compile spec, then verify the result."""
        no_inputs = ()
        collected = []

        tester = ArmTester(
            module,
            example_inputs=no_inputs,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
        )
        tester = tester.export()
        tester = tester.check_count({"torch.ops.aten.scalar_tensor.default": 1})
        tester = tester.to_edge_transform_and_lower()
        tester = tester.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        )
        tester = tester.to_executorch()
        tester.run_method_and_get_output(collected, inputs=no_inputs)

        self._verify_output(collected, expected_output)

    def _test_scalar_tensor_tosa_BI_pipeline(
        self, module: torch.nn.Module, expected_output
    ):
        """Quantize, export, lower and run ``module`` with the TOSA-0.80+BI compile spec, then verify the result."""
        no_inputs = ()
        collected = []

        tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")
        compile_spec = common.get_tosa_compile_spec(tosa_spec)
        quantizer = TOSAQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())

        tester = ArmTester(
            module,
            example_inputs=no_inputs,
            compile_spec=compile_spec,
        )
        tester = tester.quantize(
            Quantize(quantizer, get_symmetric_quantization_config())
        )
        tester = tester.export()
        # scalar_tensor is expected to be already replaced by full at this point.
        tester = tester.check_count({"torch.ops.aten.full.default": 1})
        tester = tester.to_edge_transform_and_lower()
        tester = tester.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        )
        tester = tester.to_executorch()
        tester.run_method_and_get_output(collected, inputs=no_inputs)

        self._verify_output(collected, expected_output)

    def _verify_output(self, test_outputs, expected_output):
        """Check value and dtype of the first output of the first run against ``expected_output``."""
        first_run = test_outputs[0]
        out_data = torch.squeeze(first_run[0])
        assert out_data == expected_output
        assert out_data.dtype == expected_output.dtype

    # Runs every int and float case through the MI pipeline.
    @parameterized.expand(int_test_data_suite + float_test_data_suite)
    def test_scalar_tensor_tosa_MI(  # Note TOSA MI supports all types
        self, test_name: str, scalar_value, scalar_type
    ):
        module = ScalarTensor(scalar_value, scalar_type)
        expected = torch.scalar_tensor(scalar_value, dtype=scalar_type)
        self._test_scalar_tensor_tosa_MI_pipeline(module, expected)

    # Only the float cases are run through the BI pipeline.
    @parameterized.expand(float_test_data_suite)
    def test_scalar_tensor_tosa_BI(self, test_name: str, scalar_value, scalar_type):
        module = ScalarTensor(scalar_value, scalar_type)
        expected = torch.scalar_tensor(scalar_value, dtype=scalar_type)
        self._test_scalar_tensor_tosa_BI_pipeline(module, expected)
54 changes: 54 additions & 0 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,60 @@ def serialize(
def is_quantized(self) -> bool:
return self.stages[self.stage_name(tester.Quantize)] is not None

def run_method_and_get_output(
    self,
    test_outputs: List,
    inputs: Optional[Tuple[torch.Tensor]] = None,
    stage: Optional[str] = None,
    num_runs: int = 1,
):
    """
    Run the artifact of 'stage' and collect its outputs.

    The flattened outputs of every run are appended to 'test_outputs', which
    acts as an out-parameter. Returns self to allow the function to be run in
    a test chain.

    Args:
        test_outputs: List that receives one flattened output list per run.
        inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data.
            The default (None) is random data. An empty tuple is used as given,
            which is what a module without inputs needs.
        stage: (Optional[str]): The name of the stage to run.
            The default is the latest run stage.
        num_runs: Number of times the stage artifact is run.
    """
    edge_stage = self.stages[self.stage_name(tester.ToEdge)]
    if edge_stage is None:
        edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]
    assert (
        edge_stage is not None
    ), "To get outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."

    stage = stage or self.cur
    test_stage = self.stages[stage]

    # Loop inputs and get outputs of the test stage.
    for run_iteration in range(num_runs):
        # Compare against None rather than truthiness: an explicitly passed
        # empty tuple is valid input and must not trigger random generation.
        reference_input = (
            inputs if inputs is not None else next(self.generate_random_inputs())
        )

        input_shapes = [
            generated_input.shape if hasattr(generated_input, "shape") else (1,)
            for generated_input in reference_input
        ]
        input_shape_str = ", ".join(str(list(shape)) for shape in input_shapes)
        logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")

        test_output, _ = pytree.tree_flatten(
            test_stage.run_artifact(reference_input)
        )
        test_outputs.append(test_output)

    return self

def run_method_and_compare_outputs(
self,
inputs: Optional[Tuple[torch.Tensor]] = None,
Expand Down
35 changes: 6 additions & 29 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
)
from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.replace_scalar_tensor_with_full import (
ReplaceScalarTensorWithFullPass,
)
from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
Expand Down Expand Up @@ -1723,35 +1726,9 @@ def call_operator(self, op, args, kwargs, meta):
# Register ReplaceScalarWithTensorArgPass as a Cadence pass with opt_level=0.
register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceScalarTensorWithFullPass(ExportPass):
    """
    aten.scalar_tensor can be replaced by aten.full with a shape of [1].
    scalar_tensor is not supported, so this is an opt_level=0 pass.

    The dtype requested on the scalar_tensor node is forwarded to full;
    torch.float32 is used when the node carries no dtype.
    """

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Rewrite scalar_tensor calls (edge or aten overload) into edge aten.full.

        Any other operator is passed to the base class unchanged.
        """
        if op not in {
            exir_ops.edge.aten.scalar_tensor.default,
            torch.ops.aten.scalar_tensor.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # Keep the dtype the node asked for instead of forcing float32, so
        # that e.g. integer scalar tensors are not silently turned into floats.
        dtype = kwargs.get("dtype")
        if dtype is None:
            dtype = torch.float32

        return super().call_operator(
            exir_ops.edge.aten.full.default,
            (
                [1],
                args[0],  # the scalar becomes the fill value
            ),
            {"dtype": dtype},
            meta,
        )
# Register ReplaceScalarTensorWithFullPass as a Cadence pass with opt_level=0.
register_cadence_pass(CadencePassAttribute(opt_level=0))(
ReplaceScalarTensorWithFullPass
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
Expand Down
42 changes: 42 additions & 0 deletions backends/transforms/replace_scalar_tensor_with_full.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, this is the right thing to do.

# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Tuple

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from torch.fx.node import Argument


class ReplaceScalarTensorWithFullPass(ExportPass):
    """
    aten.scalar_tensor can be replaced by aten.full with a shape of [1].

    The dtype requested on the scalar_tensor node is forwarded to full;
    torch.float32 is used when the node carries no dtype.

    NOTE(review): torch.scalar_tensor yields a 0-dim tensor while full with
    shape [1] is 1-dim - confirm that consumers of this pass tolerate that.
    """

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Rewrite scalar_tensor calls (edge or aten overload) into edge aten.full.

        Any other operator is passed to the base class unchanged.
        """
        if op not in {
            exir_ops.edge.aten.scalar_tensor.default,
            torch.ops.aten.scalar_tensor.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # dtype is an optional keyword of scalar_tensor: it may be missing or
        # None. Indexing kwargs directly would raise KeyError in the first
        # case, and passing None on would let full pick the dtype itself.
        dtype = kwargs.get("dtype")
        if dtype is None:
            dtype = torch.float32

        return super().call_operator(
            exir_ops.edge.aten.full.default,
            (
                [1],
                args[0],  # the scalar becomes the fill value
            ),
            {"dtype": dtype},
            meta,
        )
Loading