Skip to content

Arm backend: Update operator support for TOSA-1.0+INT+u55 #10849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
UnsqueezeScalarPlaceholdersPass,
)

from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.transforms.decompose_sdpa import (
DecomposeScaledDotProductAttention,
)
Expand Down Expand Up @@ -92,7 +92,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertMinMaxPass())
self.add_pass(ConvertAnyDefaultDimDimsPass())
self.add_pass(MatchWhereSelfDtypePass())
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
if self.tosa_spec.is_U55_subset:
self.add_pass(CastToInt32Pass())

self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
Expand Down Expand Up @@ -210,7 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeSqrtPass())
self.add_pass(DecomposeSiluPass())

if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
if self.tosa_spec.is_U55_subset:
# Numerically stable softmax uses amax which is not supported on Ethos-U55
self.add_pass(DecomposeSoftmaxUnstablePass())
else:
Expand Down
16 changes: 5 additions & 11 deletions backends/arm/operator_support/convolution_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import (
Tosa_0_80,
Tosa_1_00,
TosaSpecification,
)
from executorch.backends.arm.tosa_specification import TosaSpecification

from executorch.exir.dialects._ops import ops as exir_ops


Expand Down Expand Up @@ -46,13 +43,10 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
return False

# Hardware specific constraints
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
# TODO remove this once TOSA 1.0 support for u55 is added.
if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions:
return False
return True
else:
if tosa_spec.is_U55_subset:
return self._is_node_supported_u55(node)
else:
return True

def _is_node_supported_u55(self, node: fx.Node):
"""Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)"""
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/operator_support/pool_2d_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


Expand Down Expand Up @@ -46,7 +46,7 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
if not tosa_spec.is_U55_subset:
return True

# U55 case, Vela 4.2.0 (25.02 release)
Expand Down Expand Up @@ -104,7 +104,7 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
if not tosa_spec.is_U55_subset:
return True

# U55 case, Vela 4.2.0 (25.02 release)
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operator_support/reduce_sum_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


Expand All @@ -26,7 +26,7 @@ class SumSupported(SupportedTOSAOperatorCheck):
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
if not tosa_spec.is_U55_subset:
return True

# U55 case, Vela 4.2.0 (25.02 release)
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operator_support/right_shift_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)
Expand All @@ -36,6 +36,6 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

# TODO MLETORCH-525 Remove warning
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
if tosa_spec.is_U55_subset:
logging.warning(f"{node.target} may introduce one-off errors.")
return True
10 changes: 2 additions & 8 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@
EthosU55NotSupported,
EthosU55TransposeCheck,
)
from executorch.backends.arm.tosa_specification import (
Tosa_0_80,
Tosa_1_00,
TosaSpecification,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir import ExportedProgram
from executorch.exir.backend.utils import WhyNoPartitionReporter
from executorch.exir.dialects._ops import ops as exir_ops
Expand Down Expand Up @@ -129,9 +125,7 @@ def tosa_support_factory(
if not tosa_spec.support_float():
negative_checks.append(NeedsDecompositionCheck(reporter))
negative_checks.append(CheckProperQuantization(reporter))
if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or (
isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions
):
if tosa_spec.is_U55_subset:
negative_checks.append(EthosU55NotSupported(reporter))
negative_checks.append(EthosU55DtypeSupport(reporter))
negative_checks.append(EthosU55TransposeCheck(reporter))
Expand Down
5 changes: 2 additions & 3 deletions backends/arm/operators/op_rshift_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
validate_num_inputs,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00


@register_node_visitor
Expand All @@ -39,7 +38,7 @@ def define_node(

attr = ts.TosaSerializerAttribute()
round = False
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
if self.tosa_spec.is_U55_subset:
# U55 only supports INT32 and round == True
# TODO MLETORCH-525 Emulate round == False with different decomposition
round = True
Expand Down Expand Up @@ -72,7 +71,7 @@ def define_node(

attr = ts.TosaSerializerAttribute()
round = False
if isinstance(self.tosa_spec, Tosa_1_00) and "u55" in self.tosa_spec.extensions:
if self.tosa_spec.is_U55_subset:
# U55 only supports INT32 and round == True
# TODO MLETORCH-525 Emulate round == False with different decomposition
round = True
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/test/tester/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def __init__(
)
quant_stage = (
Quantize(
TOSAQuantizer(compile_spec).set_io(get_symmetric_quantization_config()),
TOSAQuantizer(tosa_profiles[tosa_version]).set_io(
get_symmetric_quantization_config()
),
get_symmetric_quantization_config(),
)
if symmetric_io_quantization
Expand Down
17 changes: 10 additions & 7 deletions backends/arm/tosa_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class TosaSpecification:
"""

version: Version
is_U55_subset: bool

def support_integer(self) -> bool:
"""
Expand All @@ -49,9 +50,13 @@ def support_float(self) -> bool:
"""
raise NotImplementedError

def __init__(self, version: Version):
def __init__(self, version: Version, extras: List[str]):
self.version = version

self.is_U55_subset = "u55" in extras
if self.is_U55_subset:
extras.remove("u55")

@staticmethod
def create_from_string(repr: str) -> "TosaSpecification":
"""
Expand Down Expand Up @@ -85,11 +90,10 @@ def create_from_string(repr: str) -> "TosaSpecification":
class Tosa_0_80(TosaSpecification):
profile: str
level_8k: bool
is_U55_subset: bool
available_profiles = ["BI", "MI"] # MT is not defined

def __init__(self, version: Version, extras: List[str]):
super().__init__(version)
super().__init__(version, extras)
assert version >= Version("0.80") and version < Version("0.90")

# Check that we only have one profile in the extensions list
Expand All @@ -105,9 +109,6 @@ def __init__(self, version: Version, extras: List[str]):
self.level_8k = "8k" in extras
if self.level_8k:
extras.remove("8k")
self.is_U55_subset = "u55" in extras
if self.is_U55_subset:
extras.remove("u55")

if len(extras) > 0:
raise ValueError(f"Unhandled extras found: {extras}")
Expand Down Expand Up @@ -147,7 +148,7 @@ class Tosa_1_00(TosaSpecification):
}

def __init__(self, version: Version, extras: List[str]):
super().__init__(version)
super().__init__(version, extras)

# Check that we have at least one profile in the extensions list
if [e in Tosa_1_00.available_profiles for e in extras].count(True) == 0:
Expand Down Expand Up @@ -194,6 +195,8 @@ def __repr__(self):
extensions = self._get_extensions_string()
if self.level_8k:
extensions += "+8k"
if self.is_U55_subset:
extensions += "+u55"
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"

def __hash__(self) -> int:
Expand Down
Loading