Skip to content

Commit 9b54ad5

Browse files
authored
Extend utilities for checking a scalar value (#2587)
Extend the `is_singleton_value` utility to check for singleton values that may be either 0D or 1D tensors. --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 3505420 commit 9b54ad5

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

onnxscript/rewriter/_ir_utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,34 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
7878
return None
7979

8080

81-
def get_singleton_value(val: ir.Value | None, rank: int | None = None):
81+
def get_singleton_value(val: ir.Value | None, rank: int | Sequence[int] | None = None):
8282
"""Returns element of a single element tensor constant value, and None otherwise.
8383
84-
If rank is specified, it checks that the value has the given rank.
84+
If an int rank is specified, it checks that the value has the given rank.
85+
If the rank is a sequence of ints, it checks that the value has one of the given ranks.
86+
87+
Thus, `rank=0` checks for a scalar, `rank=1` checks for a 1D tensor, and
88+
`rank=(0,1)` checks for either a scalar or a 1D tensor.
8589
"""
8690
np_val = get_numpy_value(val)
8791
if np_val is not None and np_val.size == 1:
88-
if rank is None or (np_val.ndim == rank):
89-
return np_val.item()
92+
value = np_val.item()
93+
if (rank is None) or (isinstance(rank, int) and (np_val.ndim == rank)):
94+
return value
95+
if isinstance(rank, Sequence) and (np_val.ndim in rank):
96+
return value
9097
return None
9198

9299

93100
def is_singleton_value(
94-
val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
101+
val: ir.Value | None,
102+
expected: float | int | Callable,
103+
*,
104+
rtol: float | None = None,
105+
rank: int | Sequence[int] | None = None,
95106
) -> bool:
96107
"""Returns True if the value is a single element tensor with given value, and False otherwise."""
97-
scalar = get_singleton_value(val)
108+
scalar = get_singleton_value(val, rank=rank)
98109
if scalar is None:
99110
return False
100111
if callable(expected):

onnxscript/rewriter/rules/fusion/_rotary_embedding.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,9 @@ def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2):
4343
def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined]
4444
check_result = pattern.MatchResult()
4545

46-
def is_one(val):
47-
"""Check if val is a 0/1 dimensional tensor with a single element equal to 1."""
48-
np_val = _ir_utils.get_numpy_value(val)
49-
return (
50-
np_val is not None
51-
and np_val.size == 1
52-
and np_val.ndim <= 1
53-
and np_val.item() == 1
54-
)
55-
56-
if not is_one(one1):
46+
if not _ir_utils.is_singleton_value(one1, 1):
5747
return check_result.fail("Unsqueeze axes is not [1]", one1)
58-
if not is_one(one2):
48+
if not _ir_utils.is_singleton_value(one2, 1):
5949
return check_result.fail("Unsqueeze axes is not [1]", one2)
6050

6151
# x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)

0 commit comments

Comments
 (0)