We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f95d42f commit d6bc799Copy full SHA for d6bc799
backends/cadence/utils/facto_util.py
@@ -43,7 +43,7 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
43
cp.Size.Ge(lambda deps, r, d: 1),
44
cp.Size.Le(lambda deps, r, d: 2**9),
45
]
46
- case "sigmoid.default" | "rsqrt.default":
+ case "sigmoid.default":
47
additional_tensor_constraints.extend(
48
[
49
cp.Dtype.In(lambda deps: [torch.float]),
@@ -52,6 +52,17 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
52
cp.Value.Le(lambda deps, dtype, struct: 2),
53
54
)
55
+ case "rsqrt.default":
56
+ additional_tensor_constraints.extend(
57
+ [
58
+ cp.Dtype.In(lambda deps: [torch.float]),
59
+ cp.Rank.Le(lambda deps: 2**2),
60
+ cp.Value.Gt(
61
+ lambda deps, dtype, struct: 0
62
+ ), # only generate real numbers
63
+ cp.Value.Le(lambda deps, dtype, struct: 2**2),
64
+ ]
65
+ )
66
case "mean.dim":
67
68
0 commit comments