Skip to content

Commit

Permalink
[Core] Optimize sampler get_logprobs (vllm-project#4594)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkooo567 authored May 8, 2024
1 parent cc466a3 commit d7740ea
Showing 1 changed file with 68 additions and 49 deletions.
117 changes: 68 additions & 49 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,13 +782,14 @@ def _get_logprobs(
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
dim=-1)
top_logprobs = top_logprobs.cpu()
top_token_ids = top_token_ids.cpu()
else:
top_logprobs, top_token_ids = None, None

selected_logprobs = selected_logprobs.cpu()
ranks = ranks.cpu()
selected_logprobs = selected_logprobs.to('cpu')
ranks = ranks.to('cpu')
if top_logprobs is not None and top_token_ids is not None:
top_logprobs = top_logprobs.to('cpu')
top_token_ids = top_token_ids.to('cpu')

# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
Expand Down Expand Up @@ -828,37 +829,48 @@ def _get_prompt_logprob_if_needed(

# Find prompt logprobs
prompt_logprobs: Optional[PromptLogprobs] = None
if (is_prompt and sampling_params.prompt_logprobs is not None):
if is_prompt and sampling_params.prompt_logprobs is not None:
prompt_logprobs = []
num_logprobs = sampling_params.prompt_logprobs
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
for token_id in next_prompt_tokens:
# Pre-select indexes and create a list. It is faster than calling .item
# repetitively.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_prompt_tokens)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_prompt_tokens)].tolist()

for idx, token_id in enumerate(next_prompt_tokens):
# Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
token_id: (selected_logprobs[selected_logprobs_idx].item(),
ranks[selected_logprobs_idx].item())
token_id: (selected_logprob_items[idx], rank_items[idx])
}

# Add top K prompt logprobs along with its rank.
if num_logprobs > 0:
prompt_logprobs_dict.update(
zip(
top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
zip(
top_logprobs[
top_logprob_idx, :num_logprobs].tolist(),
# This is ranks. Since top_logprob is sorted,
# we can just use a range here.
range(1, num_logprobs + 1))))
top_ids = top_token_ids[
top_logprob_idx, :num_logprobs].tolist()
top_probs = top_logprobs[
top_logprob_idx, :num_logprobs].tolist()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks = range(1, num_logprobs + 1)
prompt_logprobs_dict.update({
top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(top_ids, top_probs,
top_ranks)
})
prompt_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in prompt_logprobs_dict.items()
})
# + 1 to go to the next prompt token.
top_logprob_idx += 1
selected_logprobs_idx += 1

# + len(next_prompt_tokens) to go to the next prompt.
selected_logprobs_idx += len(next_prompt_tokens)
return prompt_logprobs, top_logprob_idx, selected_logprobs_idx


Expand All @@ -874,47 +886,54 @@ def _get_sampled_logprob_if_needed(
):
"""Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs
if num_logprobs is None:
num_logprobs = 0
num_logprobs = seq_group.sampling_params.logprobs or 0
sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result

if seq_group.do_sample:
assert len(next_token_ids) > 0
for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
# Calculate the sample logprob of the real sampled tokens.
# Use tuple here for performance (to use to_list()).
# token_id: (logprob, rank_from_vocab)
sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
next_token_id:
(selected_logprobs[selected_logprobs_idx].item(),
ranks[selected_logprobs_idx].item())
# Pre-select items from tensor. tolist() is faster than repetitive
# `.item()` calls.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
for idx, (next_token_id,
parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
# Get the logprob of a sampled token.
sampled_logprobs_dict = {
next_token_id: (selected_logprob_items[idx], rank_items[idx])
}
# +1 to go to the next sampled token. Note that
# selected_logprobs can contain duplicates unlike top_logprobs
# when beam search is enabled.
selected_logprobs_idx += 1

# Second, add top K logprobs along with its rank.
if num_logprobs >= 0:
sampled_logprobs_dict.update(
zip(
top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
zip(
top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range(1, num_logprobs + 1))))
# Get top K logprobs.
if num_logprobs > 0:
top_ids = top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist()
top_probs = top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks = range(1, num_logprobs + 1)
sampled_logprobs_dict.update({
top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(top_ids, top_probs,
top_ranks)
})

sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in
sampled_logprobs_dict.items()
})
# There are len(seq_ids) number of sampled tokens for the current
# sequence group in top_logprobs. Jump to the next seq_group.

# NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
# steps, which has len(seq_ids) tokens per sequence group.

# Iterate to the next sequence group in a batch.
selected_logprobs_idx += len(next_token_ids)
# Iterate to the next sequence group in a batch.
top_logprob_idx += len(seq_ids)
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx

Expand Down

0 comments on commit d7740ea

Please sign in to comment.