Skip to content

Commit f4b6ff4

Browse files
author
Zonglin Peng
committed
Update on "jarvis-nightly-operators-test-aten-clamp-out"
Differential Revision: [D85364552](https://our.internmc.facebook.com/intern/diff/D85364552/) [ghstack-poisoned]
2 parents 9c67fe2 + 806d4cc commit f4b6ff4

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

backends/cadence/utils/facto_util.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,26 @@ def _positive_valid_dim_list(tensor: torch.Tensor, length: int) -> set[tuple[int
3131
"""
3232
Generate valid permutations using only positive dimension indices.
3333
This is required for Cadence/Xtensa kernels that don't support negative indexing.
34-
34+
3535
Args:
3636
tensor: Input tensor to generate permutations for
3737
length: Number of dimensions in the permutation (must equal tensor.dim())
38-
38+
3939
Returns:
4040
Set of valid permutation tuples containing only positive indices [0, rank-1]
4141
"""
4242
if length > tensor.dim():
4343
return set()
44-
44+
4545
n = tensor.dim()
4646
pool = list(range(n))
47-
47+
4848
# Generate multiple valid permutations (only positive indices)
4949
permutations: set[tuple[int, ...]] = set()
5050
for _ in range(3): # Generate 3 different permutations for diversity
5151
perm = tuple(rm.get_random().sample(pool, length))
5252
permutations.add(perm)
53-
53+
5454
return permutations
5555

5656

@@ -202,7 +202,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
202202
cp.Value.Le(lambda deps, dtype, struct: 2**4),
203203
cp.Rank.Ge(lambda deps: 1),
204204
cp.Size.Ge(lambda deps, r, d: 1),
205-
cp.Size.In(lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d)),
205+
cp.Size.In(
206+
lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d)
207+
),
206208
max_size_constraint,
207209
]
208210
else: # input tensor(b)
@@ -213,7 +215,11 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
213215
cp.Value.Le(lambda deps, dtype, struct: 2**4),
214216
cp.Rank.Ge(lambda deps: 1),
215217
cp.Size.Ge(lambda deps, r, d: 1),
216-
cp.Size.In(lambda deps, r, d: fn.broadcast_with(fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d)),
218+
cp.Size.In(
219+
lambda deps, r, d: fn.broadcast_with(
220+
fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d
221+
)
222+
),
217223
max_size_constraint,
218224
]
219225
case "embedding.default":
@@ -365,7 +371,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
365371
if index == 1: # Only apply zero-prevention to divisor
366372
tensor_constraints.extend(
367373
[
368-
cp.Value.Ne(lambda deps, dtype, struct: 0), # Prevent division by zero
374+
cp.Value.Ne(
375+
lambda deps, dtype, struct: 0
376+
), # Prevent division by zero
369377
cp.Value.Le(lambda deps, dtype, struct: 2**3),
370378
cp.Size.Le(lambda deps, r, d: 2**3),
371379
cp.Rank.Le(lambda deps: 2**2),
@@ -400,7 +408,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
400408
cp.Dtype.In(lambda deps: [torch.int64, torch.int32, torch.float32]),
401409
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
402410
cp.Value.Le(lambda deps, dtype, struct: 2**4),
403-
cp.Value.Ne(lambda deps, dtype, struct: 0), # Prevent division by zero
411+
cp.Value.Ne(
412+
lambda deps, dtype, struct: 0
413+
), # Prevent division by zero
404414
cp.Rank.Ge(lambda deps: 1),
405415
cp.Rank.Eq(lambda deps: deps[0].dim()),
406416
cp.Size.Eq(lambda deps, r, d: fn.safe_size(deps[0], d)),
@@ -485,7 +495,9 @@ def facto_testcase_gen( # noqa: C901
485495
[
486496
cp.Optional.Eq(lambda deps: False), # Never None
487497
cp.Value.Ge(lambda deps, dtype: -(2**4)),
488-
cp.Value.Le(lambda deps, dtype: 2**4 - 2), # Leave room for max (at least 2 units)
498+
cp.Value.Le(
499+
lambda deps, dtype: 2**4 - 2
500+
), # Leave room for max (at least 2 units)
489501
]
490502
)
491503
elif in_spec.name == "max":
@@ -494,7 +506,9 @@ def facto_testcase_gen( # noqa: C901
494506
spec.inspec[index].constraints.extend(
495507
[
496508
cp.Optional.Eq(lambda deps: False), # Never None
497-
cp.Value.Ge(lambda deps, dtype: deps[1] + 2), # max >= min + 2 (sufficient gap)
509+
cp.Value.Ge(
510+
lambda deps, dtype: deps[1] + 2
511+
), # max >= min + 2 (sufficient gap)
498512
cp.Value.Le(lambda deps, dtype: 2**4),
499513
]
500514
)
@@ -540,7 +554,9 @@ def facto_testcase_gen( # noqa: C901
540554
spec.inspec[index].constraints.extend(
541555
[
542556
cp.Length.Ge(lambda deps: 1),
543-
cp.Length.Eq(lambda deps: deps[0].dim()), # Must be a complete permutation
557+
cp.Length.Eq(
558+
lambda deps: deps[0].dim()
559+
), # Must be a complete permutation
544560
cp.Optional.Eq(lambda deps: False),
545561
# Generate valid permutations using only positive indices
546562
# Cadence/Xtensa hardware kernels do not support negative dimension indices

0 commit comments

Comments
 (0)