Skip to content

Commit d6bc799

Browse files
authored
fix rsqrt test case
Differential Revision: D71762822 Pull Request resolved: #9556
1 parent f95d42f commit d6bc799

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

backends/cadence/utils/facto_util.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
4343
cp.Size.Ge(lambda deps, r, d: 1),
4444
cp.Size.Le(lambda deps, r, d: 2**9),
4545
]
46-
case "sigmoid.default" | "rsqrt.default":
46+
case "sigmoid.default":
4747
additional_tensor_constraints.extend(
4848
[
4949
cp.Dtype.In(lambda deps: [torch.float]),
@@ -52,6 +52,17 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
5252
cp.Value.Le(lambda deps, dtype, struct: 2),
5353
]
5454
)
55+
case "rsqrt.default":
56+
additional_tensor_constraints.extend(
57+
[
58+
cp.Dtype.In(lambda deps: [torch.float]),
59+
cp.Rank.Le(lambda deps: 2**2),
60+
cp.Value.Gt(
61+
lambda deps, dtype, struct: 0
62+
), # only generate real numbers
63+
cp.Value.Le(lambda deps, dtype, struct: 2**2),
64+
]
65+
)
5566
case "mean.dim":
5667
additional_tensor_constraints.extend(
5768
[

0 commit comments

Comments
 (0)