Skip to content

Commit

Permalink
[Relax] add sample_indices in sampling (#16675)
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww authored Mar 5, 2024
1 parent 46aaf61 commit fe5a350
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 92 deletions.
134 changes: 105 additions & 29 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2057,7 +2057,12 @@ def cumsum(
return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)


def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64"):
def multinomial_from_uniform(
prob: Tensor,
uniform_sample: Tensor,
sample_indices: Optional[Tensor] = None,
dtype: str = "int64",
):
"""Returns a tensor where each row contains the index sampled from the multinomial
probability distribution located in the corresponding row of tensor prob.
Expand All @@ -2075,57 +2080,97 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str =
The sum of values in each row is 1, forming a valid distribution.
uniform_sample : Tensor
The uniformly sampled 2-D tensor with the shape (batch, 1).
The uniformly sampled 2-D tensor with the shape (n, 1).
Values range from 0 to 1, indicating probabilities sampled uniformly.
sample_indices : Optional[Tensor]
The 2-D tensor with the shape (n, 1), which indicates the specific
probability distribution to sample from. The value of sample_indices[i]
determines that the ith token should be sampled from the sample_indices[i]th
probability distribution. For instance, if there are 3 distinct probability
distributions and the requirement is to sample 2, 3, and 4 tokens from them
respectively, then sample_indices would be
[[0], [0], [1], [1], [1], [2], [2], [2], [2]].
dtype : str
The data type of output tensor.
Returns
-------
result : Tensor
The computed tensor with shape (batch, 1).
The computed tensor with shape (n, 1).
Examples
--------
.. code-block:: python
prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
usample = [[0.4], [0.9]]
sample_indices = [[0], [1]]
multinomial_from_uniform(prob, usample)
-> [[1], [2]]
multinomial_from_uniform(prob, usample, sample_indices)
-> [[1], [2]]
"""
prob_dtype = prob.dtype
sample_dtype = uniform_sample.dtype
batch = prob.shape[0]
out_batch = uniform_sample.shape[0]

if sample_indices is not None:
assert (
sample_indices.shape == uniform_sample.shape
), "The shape of sample_indices must match the shape of uniform_sample."
else:
assert (
prob.shape[0] == uniform_sample.shape[0]
), "Number of samples must match the number of probability distributions."
sample_indices = Tensor.from_const(np.arange(out_batch).reshape(out_batch, 1))

sample_indices_dtype = sample_indices.dtype

@T.prim_func(private=True)
def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
batch, vocab_size = T.int64(), T.int64()
prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
usample = T.match_buffer(B, (batch, 1), sample_dtype)
output_index = T.match_buffer(C, (batch, 1), dtype)
out_batch = T.int64()
usample = T.match_buffer(B, (out_batch, 1), sample_dtype)
sample_indices = T.match_buffer(C, (out_batch, 1), sample_indices_dtype)
output_index = T.match_buffer(D, (out_batch, 1), dtype)

for ax0, ax1 in T.grid(batch, vocab_size):
for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_sample_index"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
if (
usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1]
or v_ax1 + 1 == vocab_size
):
if v_ax1 == 0:
output_index[v_ax0, 0] = 0
elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]:
elif (
usample[v_ax0, T.int64(0)]
>= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
):
output_index[v_ax0, 0] = v_ax1

cumsum_prob = cumsum(prob, axis=1, exclusive=False)

return tensor_ir_op(
_get_sample_index,
"get_sample_index",
args=[cumsum_prob, uniform_sample],
out=Tensor.placeholder([batch, 1], dtype),
args=[cumsum_prob, uniform_sample, sample_indices],
out=Tensor.placeholder([out_batch, 1], dtype),
)


def sample_top_p_top_k_from_sorted_prob(
sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor
sorted_prob: Tensor,
sorted_index: Tensor,
top_p: Tensor,
top_k: Tensor,
uniform_sample: Tensor,
sample_indices: Optional[Tensor] = None,
):
"""Samples indices from a sorted probability tensor based on top_p and top_k criteria.
Expand All @@ -2152,12 +2197,20 @@ def sample_top_p_top_k_from_sorted_prob(
to consider for top-k sampling.
uniform_sample : Tensor
Uniformly sampled values with shape (batch, 1) are used to select the output indices.
Uniformly sampled values with shape (n, 1) are used to select the output indices.
sample_indices : Optional[Tensor]
The 2-D tensor with the shape (n, 1), which indicates the specific
probability distribution to sample from. The value of sample_indices[i]
determines that the ith token should be sampled from the sample_indices[i]th
probability distribution. For instance, if there are 3 distinct probability
distributions and the requirement is to sample 2, 3, and 4 tokens from them
respectively, then sample_indices would be
[[0], [0], [1], [1], [1], [2], [2], [2], [2]].
Returns
-------
result : Tensor
The selected indices with shape (batch, 1).
The selected indices with shape (n, 1).
Examples
--------
Expand All @@ -2172,15 +2225,31 @@ def sample_top_p_top_k_from_sorted_prob(
top_p = [[0.6],[0.9]]
top_k = [[3],[2]]
uniform_sample = [[0.5], [0.6]]
sample_indices = [[0], [1]]
sample_top_p_top_k_from_sorted_prob(
sorted_prob, sorted_index,top_p, top_k, uniform_sample)
sorted_prob, sorted_index,top_p, top_k, uniform_sample, sample_indices)
-> [[2], [0]]
"""
prob_dtype = sorted_prob.dtype
index_dtype = sorted_index.dtype
batch = sorted_prob.shape[0]
prob_batch = sorted_prob.shape[0]
out_batch = uniform_sample.shape[0]

if sample_indices is not None:
assert (
sample_indices.shape == uniform_sample.shape
), "The shape of sample_indices must match the shape of uniform_sample."
else:
assert (
sorted_prob.shape[0] == uniform_sample.shape[0]
), "Number of samples must match the number of probability distributions."
sample_indices = Tensor.from_const(
np.arange(out_batch).reshape(out_batch, 1).astype(np.int64)
)
print("sample_indices: ", sample_indices)
sample_indices_dtype = sample_indices.dtype

def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
Expand All @@ -2204,27 +2273,34 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1]

@T.prim_func(private=True)
def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
def _get_index_from_sorted(
A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle
):
batch, vocab_size = T.int64(), T.int64()
out_batch = T.int64()
cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype)
usample = T.match_buffer(C, (batch, 1), prob_dtype)
indices = T.match_buffer(D, (batch, vocab_size), index_dtype)
output_index = T.match_buffer(E, (batch, 1), index_dtype)
indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
usample = T.match_buffer(D, (out_batch, 1), prob_dtype)
sample_indices = T.match_buffer(E, (out_batch, 1), sample_indices_dtype)
output_index = T.match_buffer(F, (out_batch, 1), index_dtype)

for ax0, ax1 in T.grid(batch, vocab_size):
for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_index_from_sorted"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if (
usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
usample[v_ax0, T.int64(0)]
< cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1]
/ renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
or v_ax1 + 1 == vocab_size
):
if v_ax1 == 0:
output_index[v_ax0, 0] = indices[v_ax0, 0]
elif (
usample[v_ax0, T.int64(0)]
>= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0]
>= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
/ renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
):
output_index[v_ax0, 0] = indices[v_ax0, v_ax1]

Expand All @@ -2235,16 +2311,16 @@ def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E
"get_renorm_prob",
args=[cumsum_sorted, top_p, top_k],
out=Tensor.placeholder(
[batch, 1],
[prob_batch, 1],
prob_dtype,
),
)

out_index_in_sorted = tensor_ir_op(
_get_index_from_sorted,
"get_index_from_sorted",
args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index],
out=Tensor.placeholder([batch, 1], index_dtype),
args=[cumsum_sorted, sorted_index, renorm_prob, uniform_sample, sample_indices],
out=Tensor.placeholder([out_batch, 1], index_dtype),
)
return out_index_in_sorted

Expand Down Expand Up @@ -2293,7 +2369,7 @@ def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.
top_k = T.match_buffer(D, (batch, 1), top_k_dtype)
cutoff = T.match_buffer(E, (batch, 1), prob_dtype)
for ax0, ax1 in T.grid(batch, vocab_size):
with T.block("T_get_renorm_prob"):
with T.block("T_get_renorm_cutoff"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
Expand Down
Loading

0 comments on commit fe5a350

Please sign in to comment.