Skip to content

Commit 62daa78

Browse files
author
Jan Jirmasek
committed
fix the unit tests expectations
1 parent 5fdd5e3 commit 62daa78

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

coremltools/converters/mil/mil/passes/tests/test_passes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7471,8 +7471,8 @@ def verify_sdpa_outputs(self, example_inputs: Dict[str, torch.Tensor]):
74717471

74727472
assert ops_counts[0] == 1 or ops_counts[0] == 3 # (attn_mask might be cast to bool from input fp16 dtype)
74737473
assert ops_counts[1] == 1 or ops_counts[1] == 3 # the Q seq length is less than the default min seq length
7474-
assert ops_counts[2] >= 26 * 16
7475-
assert ops_counts[3] >= 26 * 32
7474+
assert ops_counts[2] >= 11 * 16 # 11 ops (without consts) per slice
7475+
assert ops_counts[3] >= 11 * 32
74767476

74777477
predict_inputs = copy.deepcopy(example_inputs)
74787478
if "attn_mask" in predict_inputs:

0 commit comments

Comments
 (0)