Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@
from .replace_scalar_with_tensor_pass import ( # noqa
ReplaceScalarWithTensorByProfilePass,
)
from .rewrite_bool_bitwise_not_to_logical_not_pass import ( # noqa
RewriteBoolBitwiseNotToLogicalNotPass,
from .rewrite_bool_bitwise_to_logical_pass import ( # noqa
RewriteBoolBitwiseToLogicalPass,
)
from .rewrite_bool_to_fp32_cast_via_int8_pass import ( # noqa
RewriteBoolToFp32CastViaInt8Pass,
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
RemoveNoopPass,
ReplaceInfAndLimitValuesPass,
ReplaceScalarWithTensorByProfilePass,
RewriteBoolBitwiseNotToLogicalNotPass,
RewriteBoolBitwiseToLogicalPass,
RewriteBoolToFp32CastViaInt8Pass,
RewriteConvPass,
RewriteIndexPutPass,
Expand Down Expand Up @@ -238,7 +238,6 @@ def _tosa_pipeline(
self.add_passes(
[
FuseQuantizedActivationPass(),
RewriteBoolBitwiseNotToLogicalNotPass(),
RewriteBoolToFp32CastViaInt8Pass(),
CanonicalizeGatherPass(),
ConvertToClampPass(),
Expand Down Expand Up @@ -326,6 +325,7 @@ def _tosa_pipeline(
DecomposeSliceScatterPass(),
AccumulateIndexPutPass(),
RewriteIndexPutPass(),
RewriteBoolBitwiseToLogicalPass(),
DecomposeRemainderPass(),
DecomposeDivTensorModePass(),
FuseBatchNorm2dPass(exported_program),
Expand Down

This file was deleted.

49 changes: 49 additions & 0 deletions backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class RewriteBoolBitwiseToLogicalPass(ArmPass):
"""Rewrites ``aten.bitwise_*`` on boolean tensors to ``aten.logical_*``.
TOSA ``bitwise_*`` does not support boolean inputs. On boolean tensors,
``bitwise_*`` is equivalent to ``logical_*``, so this rewrite preserves
semantics while enabling lowering.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

_TARGET_TO_LOGICAL = {
exir_ops.edge.aten.bitwise_not.default: exir_ops.edge.aten.logical_not.default,
exir_ops.edge.aten.bitwise_and.Tensor: exir_ops.edge.aten.logical_and.default,
exir_ops.edge.aten.bitwise_and.Scalar: exir_ops.edge.aten.logical_and.default,
exir_ops.edge.aten.bitwise_or.Tensor: exir_ops.edge.aten.logical_or.default,
exir_ops.edge.aten.bitwise_or.Scalar: exir_ops.edge.aten.logical_or.default,
exir_ops.edge.aten.bitwise_xor.Tensor: exir_ops.edge.aten.logical_xor.default,
exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.logical_xor.default,
}

def call_operator(self, op, args, kwargs, meta):
if op not in self._TARGET_TO_LOGICAL:
return super().call_operator(op, args, kwargs, meta)

if meta["val"].dtype == torch.bool:
return super().call_operator(
self._TARGET_TO_LOGICAL[op],
args,
kwargs,
meta,
updated=True,
)

return super().call_operator(op, args, kwargs, meta)
1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from . import ( # noqa
as_strided_copy_support,
bool_bitwise_support,
clone_dim_order_support,
control_flow_support,
convolution_support,
Expand Down
39 changes: 39 additions & 0 deletions backends/arm/operator_support/bool_bitwise_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.fx as fx

from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class BoolBitwiseSupported(SupportedTOSAOperatorCheck):
"""Allow boolean bitwise ops, which are lowered to logical ops."""

targets = [
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_and.Scalar,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_or.Scalar,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.bitwise_xor.Scalar,
exir_ops.edge.aten.bitwise_not.default,
]

tosa_specs = TosaSpecification.all_versions_and_profiles()

def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool: # type: ignore[override, misc]
if node.meta["val"].dtype == torch.bool:
return True

return False
Loading
Loading