
re-implement beam search on top of vllm core #8726

Merged · 26 commits · Sep 24, 2024
Changes from 1 commit
improve code
youkaichao committed Sep 23, 2024
commit ceddba04e0e5846a48b8acd89ce62af9756e58c3
146 changes: 96 additions & 50 deletions vllm/entrypoints/llm.py
@@ -1,4 +1,6 @@
import itertools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                    Union, cast, overload)

@@ -30,6 +32,37 @@
logger = init_logger(__name__)


@dataclass
class BeamSearchSequence:
    """A sequence for beam search.
    It keeps track of the tokens and the log probability of the sequence.
    The text field is optional and will only be filled when the sequence is
    about to be returned to the user.
    """
    # The tokens include the prompt.
    tokens: List[int]
    cum_logprob: float = 0.0
    text: Optional[str] = None


@dataclass
class BeamSearchOutput:
    """The output of beam search.
    It contains the list of the best beam search sequences.
    The length of the list is equal to the beam width.
    """
    sequences: List[BeamSearchSequence]


class BeamSearchInstance:

    def __init__(self, prompt_tokens: List[int]):
        self.beams: List[BeamSearchSequence] = [
            BeamSearchSequence(tokens=prompt_tokens)
        ]
        self.completed: List[BeamSearchSequence] = []

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

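Aside (not part of the diff): a minimal, runnable sketch of how these helpers compose, using a trimmed copy of the dataclass above; the token IDs and the 0.5 probability are made up for illustration.

import math
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class BeamSearchSequence:
    tokens: List[int]          # prompt + generated tokens
    cum_logprob: float = 0.0   # sum of per-token log probabilities
    text: Optional[str] = None

root = BeamSearchSequence(tokens=[1, 2, 3])  # hypothetical prompt token IDs
child = BeamSearchSequence(
    tokens=root.tokens + [42],                     # hypothetical next token
    cum_logprob=root.cum_logprob + math.log(0.5))  # made-up probability
assert child.tokens == [1, 2, 3, 42]
assert math.isclose(child.cum_logprob, math.log(0.5))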
@@ -354,75 +387,88 @@ def generate(

    def beam_search(
        self,
        prompts: Union[List[str], List[List[int]]],
        prompts: List[Union[str, List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:

        class BeamSearchSequence:
            def __init__(self, tokens: List[int], logprob: float):
                self.tokens = tokens
                self.logprob = logprob

        class BeamSearchInstance:
            def __init__(self, tokens: List[int], logprob: float):
                self.beams: List[BeamSearchSequence] = [BeamSearchSequence(tokens, logprob)]
                self.completed: List[BeamSearchSequence] = []
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            beam_width: The number of beams to keep at each step.
            max_tokens: The max number of tokens to generate for each prompt.
        """

        tokenizer = self.get_tokenizer()
        beam_search_params = SamplingParams(logprobs=beam_width, max_tokens=1, temperature=0.0)
        # One decoding step per generate() call: max_tokens=1 extends each
        # beam by a single token, and logprobs=beam_width returns the
        # top-`beam_width` candidate tokens for every beam.
        beam_search_params = SamplingParams(logprobs=beam_width,
                                            max_tokens=1,
                                            temperature=0.0)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            tokens = prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
            instance = BeamSearchInstance(tokens, logprob=0)
            instances.append(instance)
            prompt_tokens = prompt if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            all_beams = []
            instance_to_beam_index = {}

            for i, instance in enumerate(instances):
                for beam in instance.beams:
                    all_beams.append(beam)
                    instance_to_beam_index[len(all_beams) - 1] = i
            all_beams: List[BeamSearchSequence] = list(
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))
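            # Illustrative example: with per-instance beam counts [2, 3],
            # pos == [0, 2, 5] and instance_start_and_end == [(0, 2), (2, 5)],
            # so each instance can slice its own beams back out of the
            # flattened all_beams list after the batched generate() call.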

            if len(all_beams) == 0:
                break

            prompts_batch = [TokensPrompt(prompt_token_ids=beam.tokens) for beam in all_beams]
            output = self.generate(prompts_batch, sampling_params=beam_search_params)

            new_instance_beams: Dict[int, List[BeamSearchSequence]] = {i: [] for i in range(len(instances))}

            for i, result in enumerate(output):
                current_beam = all_beams[i]
                instance_id = instance_to_beam_index[i]

                logprobs = result.outputs[0].logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    new_tokens = current_beam.tokens + [token_id]
                    new_logprob = current_beam.logprob + logprob_obj.logprob
                    new_beam = BeamSearchSequence(tokens=new_tokens, logprob=new_logprob)

                    if token_id == tokenizer.eos_token_id:
                        instances[instance_id].completed.append(new_beam)
                    else:
                        new_instance_beams[instance_id].append(new_beam)

            for i, instance in enumerate(instances):
                sorted_beams = sorted(new_instance_beams[i], key=lambda x: x.logprob, reverse=True)
                instances[i].beams = sorted_beams[:beam_width]
            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # If `result.outputs[0].logprobs` is None, the
                        # sequence has finished because of the max-model-len
                        # limit or an abort; we don't need to add it to the
                        # new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == tokenizer.eos_token_id:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=lambda x: x.cum_logprob,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed, key=lambda x: x.logprob, reverse=True)
            sorted_completed = sorted(instance.completed,
                                      key=lambda x: x.cum_logprob,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            final_tokens_list = [beam.tokens for beam in best_beams]
            final_sequence_list = [tokenizer.decode(beam.tokens) for beam in best_beams]
            outputs.append((final_tokens_list, final_sequence_list))
            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs
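
For context, a minimal usage sketch of the new API; the model name and prompt are placeholders, not taken from this PR.

from vllm import LLM

llm = LLM(model="facebook/opt-125m")   # placeholder model
outputs = llm.beam_search(
    prompts=["Hello, my name is"],     # placeholder prompt
    beam_width=4,
    max_tokens=16,
)
for output in outputs:
    for seq in output.sequences:
        # Each kept sequence carries prompt + generated tokens, a cumulative
        # logprob, and decoded text filled in just before returning.
        print(seq.cum_logprob, seq.text)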
