Skip to content

Commit 2c10b33

Browse files
committed
Add why log for mixed input dtypes; remove mixed dtype check from output
1 parent 6bd2196 commit 2c10b33

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -176,17 +176,22 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
176176
return False
177177

178178
# Check for mixed dtypes
179-
if reference_dtype is None:
180-
reference_dtype = arg_val.dtype
181-
elif arg_val.dtype != reference_dtype:
179+
reference_dtype = reference_dtype or arg_val.dtype
180+
if arg_val.dtype != reference_dtype:
181+
why(
182+
node,
183+
reason=(
184+
f"{node.target} does not support mixed input dtypes. "
185+
f"Got: [{reference_dtype}, {arg_val.dtype}]"
186+
),
187+
)
182188
return False
183189

184190
return True
185191

186192
def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
187-
# Check outputs are valid and have the same dtypes
193+
# Check outputs are valid
188194
node_val = node.meta.get("val", None)
189-
reference_dtype = None
190195

191196
if node_val is None:
192197
return True
@@ -201,12 +206,6 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
201206
if val.dtype not in valid_dtypes:
202207
return False
203208

204-
# Check for mixed dtypes
205-
if reference_dtype is None:
206-
reference_dtype = val.dtype
207-
elif val.dtype != reference_dtype:
208-
return False
209-
210209
return True
211210

212211
def _check_node_has_valid_dtype(self, node):

0 commit comments

Comments
 (0)