Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] add sample_indices in sampling #16675

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 105 additions & 29 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2057,7 +2057,12 @@ def cumsum(
return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)


def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64"):
def multinomial_from_uniform(
prob: Tensor,
uniform_sample: Tensor,
sample_indices: Optional[Tensor] = None,
dtype: str = "int64",
):
"""Returns a tensor where each row contains the index sampled from the multinomial
probability distribution located in the corresponding row of tensor prob.

Expand All @@ -2075,57 +2080,97 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str =
The sum of values in each row is 1, forming a valid distribution.

uniform_sample : Tensor
The uniformly sampled 2-D tensor with the shape (batch, 1).
The uniformly sampled 2-D tensor with the shape (n, 1).
Values range from 0 to 1, indicating probabilities sampled uniformly.

sample_indices : Optional[Tensor]
The 2-D tensor with the shape (n, 1), which selects the probability
distribution each sample is drawn from: the i-th token is sampled from
row sample_indices[i] of prob. For instance, if there are 3 distinct
probability distributions and the requirement is to sample 2, 3, and 4
tokens from them respectively, then sample_indices (flattened) would be
[0, 0, 1, 1, 1, 2, 2, 2, 2]. If omitted, the i-th token is sampled from
row i, which requires n to equal the number of distributions in prob.

dtype : str
The data type of output tensor.


Returns
-------
result : Tensor
The computed tensor with shape (batch, 1).
The computed tensor with shape (n, 1).

Examples
--------
.. code-block:: python

prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
usample = [[0.4], [0.9]]
sample_indices = [[0], [1]]

multinomial_from_uniform(prob, usample)
-> [[1], [2]]
multinomial_from_uniform(prob, usample, sample_indices)
-> [[1], [2]]
"""
prob_dtype = prob.dtype
sample_dtype = uniform_sample.dtype
batch = prob.shape[0]
out_batch = uniform_sample.shape[0]

if sample_indices is not None:
assert (
sample_indices.shape == uniform_sample.shape
), "The shape of sample_indices must match the shape of uniform_sample."
else:
assert (
prob.shape[0] == uniform_sample.shape[0]
), "Number of samples must match the number of probability distributions."
sample_indices = Tensor.from_const(np.arange(out_batch).reshape(out_batch, 1))

sample_indices_dtype = sample_indices.dtype

@T.prim_func(private=True)
def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
batch, vocab_size = T.int64(), T.int64()
prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
usample = T.match_buffer(B, (batch, 1), sample_dtype)
output_index = T.match_buffer(C, (batch, 1), dtype)
out_batch = T.int64()
usample = T.match_buffer(B, (out_batch, 1), sample_dtype)
sample_indices = T.match_buffer(C, (out_batch, 1), sample_indices_dtype)
output_index = T.match_buffer(D, (out_batch, 1), dtype)

for ax0, ax1 in T.grid(batch, vocab_size):
for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_sample_index"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
if (
usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1]
or v_ax1 + 1 == vocab_size
):
if v_ax1 == 0:
output_index[v_ax0, 0] = 0
elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]:
elif (
usample[v_ax0, T.int64(0)]
>= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
):
output_index[v_ax0, 0] = v_ax1

cumsum_prob = cumsum(prob, axis=1, exclusive=False)

return tensor_ir_op(
_get_sample_index,
"get_sample_index",
args=[cumsum_prob, uniform_sample],
out=Tensor.placeholder([batch, 1], dtype),
args=[cumsum_prob, uniform_sample, sample_indices],
out=Tensor.placeholder([out_batch, 1], dtype),
)


def sample_top_p_top_k_from_sorted_prob(
sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor
sorted_prob: Tensor,
sorted_index: Tensor,
top_p: Tensor,
top_k: Tensor,
uniform_sample: Tensor,
sample_indices: Optional[Tensor] = None,
):
"""Samples indices from a sorted probability tensor based on top_p and top_k criteria.

Expand All @@ -2152,12 +2197,20 @@ def sample_top_p_top_k_from_sorted_prob(
to consider for top-k sampling.

uniform_sample : Tensor
Uniformly sampled values with shape (batch, 1) are used to select the output indices.
Uniformly sampled values with shape (n, 1) are used to select the output indices.

sample_indices : Optional[Tensor]
The 2-D tensor with the shape (n, 1), which selects the probability
distribution each sample is drawn from: the i-th token is sampled from
row sample_indices[i] of sorted_prob. For instance, if there are 3
distinct probability distributions and the requirement is to sample 2, 3,
and 4 tokens from them respectively, then sample_indices (flattened) would
be [0, 0, 1, 1, 1, 2, 2, 2, 2]. If omitted, the i-th token is sampled from
row i, which requires n to equal the number of rows in sorted_prob.

Returns
-------
result : Tensor
The selected indices with shape (batch, 1).
The selected indices with shape (n, 1).

Examples
--------
Expand All @@ -2172,15 +2225,31 @@ def sample_top_p_top_k_from_sorted_prob(
top_p = [[0.6],[0.9]]
top_k = [[3],[2]]
uniform_sample = [[0.5], [0.6]]
sample_indices = [[0], [1]]

sample_top_p_top_k_from_sorted_prob(
sorted_prob, sorted_index,top_p, top_k, uniform_sample)
sorted_prob, sorted_index,top_p, top_k, uniform_sample, sample_indices)
-> [2, 0]

"""
prob_dtype = sorted_prob.dtype
index_dtype = sorted_index.dtype
batch = sorted_prob.shape[0]
prob_batch = sorted_prob.shape[0]
out_batch = uniform_sample.shape[0]

if sample_indices is not None:
assert (
sample_indices.shape == uniform_sample.shape
), "The shape of sample_indices must match the shape of uniform_sample."
else:
assert (
sorted_prob.shape[0] == uniform_sample.shape[0]
), "Number of samples must match the number of probability distributions."
sample_indices = Tensor.from_const(
np.arange(out_batch).reshape(out_batch, 1).astype(np.int64)
)
print("sample_indices: ", sample_indices)
sample_indices_dtype = sample_indices.dtype

def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
Expand All @@ -2204,27 +2273,34 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1]

@T.prim_func(private=True)
def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
def _get_index_from_sorted(
A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle
):
batch, vocab_size = T.int64(), T.int64()
out_batch = T.int64()
cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype)
usample = T.match_buffer(C, (batch, 1), prob_dtype)
indices = T.match_buffer(D, (batch, vocab_size), index_dtype)
output_index = T.match_buffer(E, (batch, 1), index_dtype)
indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
usample = T.match_buffer(D, (out_batch, 1), prob_dtype)
sample_indices = T.match_buffer(E, (out_batch, 1), sample_indices_dtype)
output_index = T.match_buffer(F, (out_batch, 1), index_dtype)

for ax0, ax1 in T.grid(batch, vocab_size):
for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_index_from_sorted"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if (
usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
usample[v_ax0, T.int64(0)]
< cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1]
/ renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
or v_ax1 + 1 == vocab_size
):
if v_ax1 == 0:
output_index[v_ax0, 0] = indices[v_ax0, 0]
elif (
usample[v_ax0, T.int64(0)]
>= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0]
>= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
/ renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
):
output_index[v_ax0, 0] = indices[v_ax0, v_ax1]

Expand All @@ -2235,16 +2311,16 @@ def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E
"get_renorm_prob",
args=[cumsum_sorted, top_p, top_k],
out=Tensor.placeholder(
[batch, 1],
[prob_batch, 1],
prob_dtype,
),
)

out_index_in_sorted = tensor_ir_op(
_get_index_from_sorted,
"get_index_from_sorted",
args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index],
out=Tensor.placeholder([batch, 1], index_dtype),
args=[cumsum_sorted, sorted_index, renorm_prob, uniform_sample, sample_indices],
out=Tensor.placeholder([out_batch, 1], index_dtype),
)
return out_index_in_sorted

Expand Down Expand Up @@ -2293,7 +2369,7 @@ def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.
top_k = T.match_buffer(D, (batch, 1), top_k_dtype)
cutoff = T.match_buffer(E, (batch, 1), prob_dtype)
for ax0, ax1 in T.grid(batch, vocab_size):
with T.block("T_get_renorm_prob"):
with T.block("T_get_renorm_cutoff"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
Expand Down
Loading
Loading