
Incrementally decode output tokens #121


Merged 9 commits on May 24, 2023
2 changes: 1 addition & 1 deletion cacheflow/core/scheduler.py
@@ -291,7 +291,7 @@ def update(
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token(output.output_token, output.logprobs)
seq.append_token_id(output.output_token, output.logprobs)
return self.running.copy()

def free_seq(self, seq: Sequence) -> None:
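For context, a minimal sketch of how the renamed call is driven from the scheduler's update loop. The `SequenceOutputsStub` class below is a hypothetical stand-in, assumed only from the `output.output_token` and `output.logprobs` accesses visible in the hunk above; it is not the real sampler output class.

```python
# Hypothetical stand-in for the per-sequence sampler output consumed above.
from dataclasses import dataclass
from typing import Dict

@dataclass
class SequenceOutputsStub:
    output_token: int            # the token id sampled for this sequence
    logprobs: Dict[int, float]   # log-probabilities of candidate token ids

def apply_output(seq, output: SequenceOutputsStub) -> None:
    # Mirrors the updated scheduler line: the method is now named
    # append_token_id to make clear it takes a token id, not a token string.
    seq.append_token_id(output.output_token, output.logprobs)
```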
11 changes: 8 additions & 3 deletions cacheflow/sequence.py
@@ -24,7 +24,7 @@ def __init__(
self.output_token_ids: List[int] = []
self.cumulative_logprob = 0.0

def append_token(self, token_id: int, logprob: float) -> None:
def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id)
self.cumulative_logprob += logprob

@@ -64,6 +64,7 @@ def __init__(

self.data = SequenceData(prompt_token_ids)
self.output_logprobs: List[Dict[int, float]] = []
self.output_tokens: List[str] = []
self.output_text = ""

self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -92,11 +93,15 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
last_block.append_tokens(token_ids[:num_empty_slots])
token_ids = token_ids[num_empty_slots:]

def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, float],
) -> None:
assert token_id in logprobs
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.data.append_token(token_id, logprobs[token_id])
self.data.append_token_id(token_id, logprobs[token_id])

def get_len(self) -> int:
return self.data.get_len()
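A minimal sketch of the per-sequence bookkeeping this diff touches, assuming the surrounding class definitions match the hunks shown here. The simplified class below merges the `SequenceData` and `Sequence` fields for illustration only; it is not the real implementation.

```python
# Simplified, hypothetical view of the state updated by append_token_id.
from typing import Dict, List

class SequenceSketch:
    def __init__(self) -> None:
        self.output_token_ids: List[int] = []            # sampled token ids
        self.cumulative_logprob = 0.0
        self.output_logprobs: List[Dict[int, float]] = []
        self.output_tokens: List[str] = []               # new: decoded token strings
        self.output_text = ""                            # new: running decoded text

    def append_token_id(self, token_id: int, logprobs: Dict[int, float]) -> None:
        # The sampled token id must be among the logprobs returned by the sampler.
        assert token_id in logprobs
        self.output_token_ids.append(token_id)
        self.output_logprobs.append(logprobs)
        self.cumulative_logprob += logprobs[token_id]
```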
24 changes: 12 additions & 12 deletions cacheflow/server/llm_server.py
@@ -14,7 +14,8 @@
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import initialize_cluster
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.server.tokenizer_utils import (get_tokenizer,
detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
from cacheflow.utils import Counter
from cacheflow.worker.worker import Worker
@@ -184,18 +185,17 @@ def step(self) -> List[RequestOutput]:
return request_outputs

def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Batch-decode the sequence outputs.
seqs: List[Sequence] = []
# Decode the sequence outputs.
for seq_group in seq_groups:
seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
output_tokens_per_seq = []
for seq in seqs:
output_tokens_per_seq.append(seq.get_output_token_ids())
output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
skip_special_tokens=True)
# Update the sequences with the output texts.
for seq, output_text in zip(seqs, output_texts):
seq.output_text = output_text
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_token, new_output_text = detokenize_incrementally(
self.tokenizer,
seq.output_tokens,
seq.get_last_token_id(),
skip_special_tokens=True,
)
seq.output_tokens.append(new_token)
seq.output_text = new_output_text

def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Stop the sequences.
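The change above replaces the per-step `batch_decode` of every sequence's full output with a single-token incremental decode. A self-contained sketch of that loop follows; the `SeqStub` class and the "gpt2" tokenizer are assumptions for illustration, while the real server iterates over RUNNING sequences in each SequenceGroup.

```python
# Sketch of the incremental decode path; assumes this PR's helper is importable.
from dataclasses import dataclass, field
from typing import List

from transformers import AutoTokenizer
from cacheflow.server.tokenizer_utils import detokenize_incrementally

@dataclass
class SeqStub:                                   # stand-in for cacheflow.sequence.Sequence
    output_token_ids: List[int] = field(default_factory=list)
    output_tokens: List[str] = field(default_factory=list)
    output_text: str = ""

    def get_last_token_id(self) -> int:
        return self.output_token_ids[-1]

tokenizer = AutoTokenizer.from_pretrained("gpt2")
seq = SeqStub()
for token_id in tokenizer.encode("Incremental decoding avoids re-decoding everything."):
    seq.output_token_ids.append(token_id)        # pretend this id was just sampled
    new_token, new_text = detokenize_incrementally(
        tokenizer, seq.output_tokens, seq.get_last_token_id(),
        skip_special_tokens=True)
    seq.output_tokens.append(new_token)          # cache so the next step decodes one token
    seq.output_text = new_text
print(seq.output_text)
```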
63 changes: 62 additions & 1 deletion cacheflow/server/tokenizer_utils.py
@@ -1,8 +1,12 @@
from typing import Union
from typing import List, Tuple, Union

from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)

from cacheflow.logger import init_logger

logger = init_logger(__name__)

_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
# LLaMA fast tokenizer has a bug related to protobuf.
# See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
@@ -17,5 +21,62 @@ def get_tokenizer(
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
config = AutoConfig.from_pretrained(model_name)
if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
if kwargs.get("use_fast", False):
raise ValueError(
f"Cannot use the fast tokenizer for {config.model_type} due to "
"bugs in the fast tokenizer.")
logger.info(
f"Using the slow tokenizer for {config.model_type} due to bugs in "
"the fast tokenizer. This could potentially lead to performance "
"degradation.")
kwargs["use_fast"] = False
return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)


def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prev_output_tokens: List[str],
new_token_id: int,
skip_special_tokens: bool,
) -> Tuple[str, str]:
"""Detokenizes the new token in conjuction with the previous output tokens.

NOTE: This function does not update prev_output_tokens.

Returns:
new_token: The new token as a string.
output_text: The new output text as a string.
"""
new_token = tokenizer.convert_ids_to_tokens(
new_token_id, skip_special_tokens=skip_special_tokens)
output_tokens = prev_output_tokens + [new_token]

# Convert the tokens to a string.
# Optimization: If the tokenizer does not have `added_tokens_encoder`,
# then we can directly use `convert_tokens_to_string`.
if not getattr(tokenizer, "added_tokens_encoder", {}):
output_text = tokenizer.convert_tokens_to_string(output_tokens)
return new_token, output_text

# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts = []
current_sub_text = []
for token in output_tokens:
if skip_special_tokens and token in tokenizer.all_special_tokens:
continue
if token in tokenizer.added_tokens_encoder:
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
output_text = " ".join(sub_texts)
return new_token, output_text
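The slow loop above exists only to handle tokenizers with added tokens, which cannot be passed through `convert_tokens_to_string` together with regular vocabulary tokens. A small, hypothetical exercise of that branch follows; the "gpt2" slow tokenizer and the "<tool>" token are illustrative assumptions, not part of this PR.

```python
# Hypothetical check of the added-tokens branch; "<tool>" is an illustrative token.
from typing import List

from transformers import AutoTokenizer
from cacheflow.server.tokenizer_utils import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
tokenizer.add_tokens(["<tool>"])                 # added_tokens_encoder is now non-empty

output_tokens: List[str] = []
output_text = ""
for token_id in tokenizer.encode("call <tool> now"):
    new_token, output_text = detokenize_incrementally(
        tokenizer, output_tokens, token_id, skip_special_tokens=True)
    output_tokens.append(new_token)
print(output_text)   # added tokens are emitted as-is, joined with spaces around sub-texts
```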