Skip to content

Add mixed dtype check for XNNPACK partitioner #9612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 27, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions backends/xnnpack/partition/config/xnnpack_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def check_common_constraints(
return True

def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
# Check inputs are valid dtypes
# Check inputs are valid and have the same dtypes
# Gather all args which are nodes
args_to_check = []
reference_dtype = None
for arg in node.args:
if isinstance(arg, list) or isinstance(arg, tuple):
for item in arg:
Expand Down Expand Up @@ -174,11 +175,32 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
if arg_val.dtype not in valid_dtypes:
return False

# Use the first dtype as reference
reference_dtype = reference_dtype or arg_val.dtype

# Check for mixed dtypes
if arg_val.dtype != reference_dtype:
# Get op name if the attribute exists, otherwise use the full node target for logging
op_name = (
node.target.__name__
if hasattr(node.target, "__name__")
else str(node.target)
)
why(
node,
reason=(
f"{op_name} does not support mixed input dtypes, "
f"got: [{reference_dtype}, {arg_val.dtype}]"
),
)
return False

return True

def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
# Check outputs are valid dtype
# Check outputs are valid
node_val = node.meta.get("val", None)

if node_val is None:
return True

Expand Down
Loading