Skip to content

Commit 2532d80

Browse files
author
Zonglin Peng
committed
Update base for Update on "jarvis-nightly-operators-test-aten-where-out"
Differential Revision: [D85364554](https://our.internmc.facebook.com/intern/diff/D85364554/) [ghstack-poisoned]
1 parent 1ebddff commit 2532d80

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

backends/cadence/utils/facto_util.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,26 @@ def _positive_valid_dim_list(tensor: torch.Tensor, length: int) -> set[tuple[int
3131
"""
3232
Generate valid permutations using only positive dimension indices.
3333
This is required for Cadence/Xtensa kernels that don't support negative indexing.
34-
34+
3535
Args:
3636
tensor: Input tensor to generate permutations for
3737
length: Number of dimensions in the permutation (must equal tensor.dim())
38-
38+
3939
Returns:
4040
Set of valid permutation tuples containing only positive indices [0, rank-1]
4141
"""
4242
if length > tensor.dim():
4343
return set()
44-
44+
4545
n = tensor.dim()
4646
pool = list(range(n))
47-
47+
4848
# Generate multiple valid permutations (only positive indices)
4949
permutations: set[tuple[int, ...]] = set()
5050
for _ in range(3): # Generate 3 different permutations for diversity
5151
perm = tuple(rm.get_random().sample(pool, length))
5252
permutations.add(perm)
53-
53+
5454
return permutations
5555

5656

@@ -378,7 +378,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
378378
if index == 1: # Only apply zero-prevention to divisor
379379
tensor_constraints.extend(
380380
[
381-
cp.Value.Ne(lambda deps, dtype, struct: 0), # Prevent division by zero
381+
cp.Value.Ne(
382+
lambda deps, dtype, struct: 0
383+
), # Prevent division by zero
382384
cp.Value.Le(lambda deps, dtype, struct: 2**3),
383385
cp.Size.Le(lambda deps, r, d: 2**3),
384386
cp.Rank.Le(lambda deps: 2**2),
@@ -413,7 +415,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
413415
cp.Dtype.In(lambda deps: [torch.int64, torch.int32, torch.float32]),
414416
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
415417
cp.Value.Le(lambda deps, dtype, struct: 2**4),
416-
cp.Value.Ne(lambda deps, dtype, struct: 0), # Prevent division by zero
418+
cp.Value.Ne(
419+
lambda deps, dtype, struct: 0
420+
), # Prevent division by zero
417421
cp.Rank.Ge(lambda deps: 1),
418422
cp.Rank.Eq(lambda deps: deps[0].dim()),
419423
cp.Size.Eq(lambda deps, r, d: fn.safe_size(deps[0], d)),
@@ -531,7 +535,9 @@ def facto_testcase_gen( # noqa: C901
531535
spec.inspec[index].constraints.extend(
532536
[
533537
cp.Length.Ge(lambda deps: 1),
534-
cp.Length.Eq(lambda deps: deps[0].dim()), # Must be a complete permutation
538+
cp.Length.Eq(
539+
lambda deps: deps[0].dim()
540+
), # Must be a complete permutation
535541
cp.Optional.Eq(lambda deps: False),
536542
# Generate valid permutations using only positive indices
537543
# Cadence/Xtensa hardware kernels do not support negative dimension indices

0 commit comments

Comments
 (0)