-
Notifications
You must be signed in to change notification settings - Fork 537
Arm backend: Move ReplaceScalarTensorWithFullPass to transforms #8998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
dc8c21a
Arm backend: Move ReplaceScalarTensorWithFullPass to transforms
mansnils f52c63c
Update conformer model test
mansnils 04be2a8
Merge remote-tracking branch 'upstream/main' into ops
mansnils edc1470
Fix formatting
mansnils ada2b5a
Merge remote-tracking branch 'upstream/main' into ops
mansnils 8875bc5
Merge branch 'main' into ops
zingo 3b358e9
Merge branch 'main' into ops
zingo 4061784
Merge branch 'main' into ops
zingo 1ddccd7
Merge branch 'main' into ops
zingo 3c7cbb8
Merge branch 'main' into ops
zingo 69879d5
Merge branch 'main' into ops
zingo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# Copyright 2024-2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
from executorch.backends.arm.quantizer.arm_quantizer import ( | ||
get_symmetric_quantization_config, | ||
TOSAQuantizer, | ||
) | ||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from executorch.backends.arm.tosa_specification import TosaSpecification | ||
from executorch.backends.xnnpack.test.tester.tester import Quantize | ||
from parameterized import parameterized | ||
|
||
|
||
float_test_data_suite = [ | ||
# (test_name, scalar input, scalar input type,) | ||
( | ||
"scalar_tensor_float_1", | ||
3.7, | ||
torch.float32, | ||
), | ||
( | ||
"scalar_tensor_float_2", | ||
66, | ||
torch.float32, | ||
), | ||
] | ||
|
||
int_test_data_suite = [ | ||
# (test_name, scalar input, scalar input type,) | ||
( | ||
"scalar_tensor_int32", | ||
33, | ||
torch.int32, | ||
), | ||
( | ||
"scalar_tensor_int8", | ||
8, | ||
torch.int8, | ||
), | ||
( | ||
"scalar_tensor_int16", | ||
16 * 16 * 16, | ||
torch.int16, | ||
), | ||
] | ||
|
||
|
||
class ScalarTensor(torch.nn.Module): | ||
def __init__(self, scalar, dtype=torch.float32): | ||
super().__init__() | ||
self.scalar = scalar | ||
self.dtype = dtype | ||
|
||
def forward(self): | ||
return torch.scalar_tensor(self.scalar, dtype=self.dtype) | ||
|
||
|
||
class TestScalarTensor(unittest.TestCase): | ||
|
||
def _test_scalar_tensor_tosa_MI_pipeline( | ||
self, module: torch.nn.Module, expected_output | ||
): | ||
test_outputs = [] | ||
in_data = () | ||
|
||
( | ||
ArmTester( | ||
module, | ||
example_inputs=in_data, | ||
compile_spec=common.get_tosa_compile_spec( | ||
"TOSA-0.80+MI", | ||
), | ||
) | ||
.export() | ||
.check_count({"torch.ops.aten.scalar_tensor.default": 1}) | ||
.to_edge_transform_and_lower() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_get_output(test_outputs, inputs=in_data) | ||
) | ||
self._verify_output(test_outputs, expected_output) | ||
|
||
def _test_scalar_tensor_tosa_BI_pipeline( | ||
self, module: torch.nn.Module, expected_output | ||
): | ||
test_outputs = [] | ||
in_data = () | ||
tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI") | ||
compile_spec = common.get_tosa_compile_spec(tosa_spec) | ||
quantizer = TOSAQuantizer(tosa_spec).set_io(get_symmetric_quantization_config()) | ||
|
||
( | ||
ArmTester( | ||
module, | ||
example_inputs=in_data, | ||
compile_spec=compile_spec, | ||
) | ||
.quantize(Quantize(quantizer, get_symmetric_quantization_config())) | ||
.export() | ||
.check_count({"torch.ops.aten.full.default": 1}) # Already replaced | ||
.to_edge_transform_and_lower() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_get_output(test_outputs, inputs=in_data) | ||
) | ||
self._verify_output(test_outputs, expected_output) | ||
|
||
def _verify_output(self, test_outputs, expected_output): | ||
out_data = torch.squeeze(test_outputs[0][0]) | ||
assert out_data == expected_output | ||
assert out_data.dtype == expected_output.dtype | ||
|
||
@parameterized.expand(int_test_data_suite + float_test_data_suite) | ||
def test_scalar_tensor_tosa_MI( # Note TOSA MI supports all types | ||
self, test_name: str, scalar_value, scalar_type | ||
): | ||
scalar = scalar_value | ||
dtype = scalar_type | ||
self._test_scalar_tensor_tosa_MI_pipeline( | ||
ScalarTensor(scalar, dtype), torch.scalar_tensor(scalar, dtype=dtype) | ||
) | ||
|
||
@parameterized.expand(float_test_data_suite) | ||
def test_scalar_tensor_tosa_BI(self, test_name: str, scalar_value, scalar_type): | ||
scalar = scalar_value | ||
dtype = scalar_type | ||
self._test_scalar_tensor_tosa_BI_pipeline( | ||
ScalarTensor(scalar, dtype), torch.scalar_tensor(scalar, dtype=dtype) | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Dict, Tuple | ||
|
||
import torch | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue | ||
from torch.fx.node import Argument | ||
|
||
|
||
class ReplaceScalarTensorWithFullPass(ExportPass): | ||
""" | ||
aten.scalar_tensor can be replaced by aten.full with a shape of [1]. | ||
""" | ||
|
||
def call_operator( | ||
self, | ||
op, | ||
args: Tuple[Argument, ...], | ||
kwargs: Dict[str, Argument], | ||
meta: NodeMetadata, | ||
) -> ProxyValue: | ||
if op not in { | ||
exir_ops.edge.aten.scalar_tensor.default, | ||
torch.ops.aten.scalar_tensor.default, | ||
}: | ||
return super().call_operator(op, args, kwargs, meta) | ||
|
||
return super().call_operator( | ||
exir_ops.edge.aten.full.default, | ||
( | ||
[1], | ||
args[0], | ||
), | ||
{"dtype": kwargs["dtype"]}, | ||
meta, | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you this is the right thing to do.