Speculative Decoding #2607

Closed · wants to merge 2 commits
examples/api_client_spec_dec.py (new file: 85 additions, 0 deletions)
@@ -0,0 +1,85 @@
"""Example code for running queries from vLLM API server.
Sample Usage:
1. Launch a vLLM server with speculative decoding enabled:
python -m vllm.entrypoints.api_server --model meta-llama/Llama-2-70b-hf \
--tensor-parallel-size 8 --draft-model TinyLlama/TinyLlama-1.1B-Chat-v0.6 --speculate-length 5
2. Run query using this script:
python api_client_spec_dec.py --prompt "San Francisco is a" --stream
"""

import argparse
import json
from typing import Iterable, List

import requests


def clear_line(n: int = 1) -> None:
    LINE_UP = '\033[1A'
    LINE_CLEAR = '\x1b[2K'
    for _ in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def post_http_request(prompt: str,
                      api_url: str,
                      max_tokens: int = 256,
                      stream: bool = False) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "temperature": 0.0,
        "max_tokens": max_tokens,
        "stream": stream,
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=True)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    for chunk in response.iter_content(
            chunk_size=8192,
            decode_unicode=True,
    ):
        if chunk:
            data = json.loads(chunk.decode("utf-8")[:-1])
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> List[str]:
    data = json.loads(response.content)
    output = data["text"]
    return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    parser.add_argument("--max-tokens", type=int, default=256)
    parser.add_argument("--stream", action="store_true")
    args = parser.parse_args()
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    max_tokens = args.max_tokens
    stream = args.stream

    print(f"Prompt: {prompt!r}\n", flush=True)
    response = post_http_request(prompt, api_url, max_tokens, stream)

    if stream:
        num_printed_lines = 0
        char_printed = 0
        for h in get_streaming_response(response):
            line = h[0]
            new_chars = line[char_printed:]
            char_printed = len(line)
            print(f"{new_chars}", flush=True, end='')
            num_printed_lines += 1
        print()
    else:
        output = get_response(response)
        line = output[0]
        print(f"{line!r}", flush=True)
examples/offline_inference_spec_dec.py (new file: 30 additions, 0 deletions)
@@ -0,0 +1,30 @@
import numpy as np
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, max_tokens=512)

# Create an LLM.
llm = LLM(model="lmsys/vicuna-13b-v1.5",
          draft_model="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
          speculate_length=5)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    mean_num_accepted = np.mean(output.outputs[0].acceptance_history)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}, "
          f"Mean acceptance length={mean_num_accepted}")
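With speculate_length=5 the draft model proposes five tokens per step, and acceptance_history appears to record how many of them the target model accepted at each step; the PR does not document the field here, so treat that as an assumption. Under that assumption, a rough sketch of how the printed mean relates to tokens produced per target-model forward pass:

# Hedged sketch with illustrative numbers: if `a` draft tokens are accepted on
# average and the target model contributes one token of its own per
# verification step, each target forward pass yields roughly a + 1 tokens.
mean_num_accepted = 3.2  # made-up example value for the metric printed above
tokens_per_target_step = mean_num_accepted + 1.0
print(f"~{tokens_per_target_step:.1f} tokens per target-model step "
      f"(vs. 1.0 without speculative decoding)")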
vllm/config.py (10 additions, 3 deletions)
@@ -79,6 +79,7 @@ def __init__(
         quantization: Optional[str] = None,
         enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
+        is_draft_model: Optional[bool] = False,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -93,6 +94,8 @@ def __init__(
         self.quantization = quantization
         self.enforce_eager = enforce_eager
         self.max_context_len_to_capture = max_context_len_to_capture
+        # Flag to mark if the model is used as draft model for speculative decoding
+        self.is_draft_model = is_draft_model

         if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
             # download model from ModelScope hub,
@@ -198,7 +201,8 @@ def verify_with_parallel_config(
         parallel_config: "ParallelConfig",
     ) -> None:
         total_num_attention_heads = self.hf_config.num_attention_heads
-        tensor_parallel_size = parallel_config.tensor_parallel_size
+        tensor_parallel_size = parallel_config.draft_model_tp_size if self.is_draft_model \
+            else parallel_config.tensor_parallel_size
         if total_num_attention_heads % tensor_parallel_size != 0:
             raise ValueError(
                 f"Total number of attention heads ({total_num_attention_heads})"
@@ -269,8 +273,9 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         # the tensor parallel size. We will replicate the KV heads in the
         # case where the number of KV heads is smaller than the tensor
         # parallel size so each GPU has at least one KV head.
-        return max(1,
-                   total_num_kv_heads // parallel_config.tensor_parallel_size)
+        tensor_parallel_size = parallel_config.draft_model_tp_size if self.is_draft_model \
+            else parallel_config.tensor_parallel_size
+        return max(1, total_num_kv_heads // tensor_parallel_size)

     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
@@ -378,12 +383,14 @@ def __init__(
         worker_use_ray: bool,
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
+        draft_model_tp_size: Optional[int] = 1,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
+        self.draft_model_tp_size = draft_model_tp_size

         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
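These config changes let the draft model use its own tensor-parallel degree (draft_model_tp_size, default 1) while the target model keeps tensor_parallel_size. A minimal standalone sketch of the selection logic the diff adds; the _ParallelCfg class below is illustrative only, not vLLM's actual ParallelConfig:

# Illustrative sketch of the TP-size selection introduced in this diff.
class _ParallelCfg:  # stand-in for vllm.config.ParallelConfig
    def __init__(self, tensor_parallel_size: int, draft_model_tp_size: int = 1):
        self.tensor_parallel_size = tensor_parallel_size
        self.draft_model_tp_size = draft_model_tp_size


def effective_tp_size(parallel_config: _ParallelCfg, is_draft_model: bool) -> int:
    # Mirrors the branch used in verify_with_parallel_config and get_num_kv_heads.
    return (parallel_config.draft_model_tp_size
            if is_draft_model else parallel_config.tensor_parallel_size)


cfg = _ParallelCfg(tensor_parallel_size=8, draft_model_tp_size=1)
assert effective_tp_size(cfg, is_draft_model=False) == 8  # target model shards over 8 GPUs
assert effective_tp_size(cfg, is_draft_model=True) == 1   # draft model runs on one GPU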
vllm/core/block_manager.py (49 additions, 5 deletions)
@@ -7,6 +7,10 @@
 from vllm.utils import Device


+def cdiv(a: int, b: int) -> int:
+    return (a + b - 1) // b
+
+
 class BlockAllocator:
     """Manages free physical token blocks for a device.

@@ -97,11 +101,16 @@ def __init__(
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}

-    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
+    def can_allocate(self,
+                     seq_group: SequenceGroup,
+                     num_padding_tokens: int = 0) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
-        num_required_blocks = len(seq.logical_token_blocks)
+        seq_len = seq.data.get_len()
+        assert cdiv(seq_len, self.block_size) == len(seq.logical_token_blocks)
+        num_required_blocks = cdiv(seq_len + num_padding_tokens,
+                                   self.block_size)

         if seq_group.prefix is not None and seq_group.prefix.allocated:
             num_required_blocks -= seq_group.prefix.get_num_blocks()
@@ -120,13 +129,15 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         else:
             return AllocStatus.LATER

-    def allocate(self, seq_group: SequenceGroup) -> None:
+    def allocate(self,
+                 seq_group: SequenceGroup,
+                 num_padding_tokens: int = 0) -> None:
         # NOTE: Here we assume that all sequences in the group have the same
         # prompt.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
-
+        seq_len = seq.data.get_len()
         # Allocate new physical token blocks that will store the prompt tokens.
-        num_prompt_blocks = len(seq.logical_token_blocks)
+        num_prompt_blocks = cdiv(seq_len + num_padding_tokens, self.block_size)

         block_table: BlockTable = []
         prefix_block_table: BlockTable = []
@@ -163,6 +174,16 @@ def allocate(self, seq_group: SequenceGroup) -> None:
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             self.block_tables[seq.seq_id] = block_table.copy()

+    def can_append_multiple_slots(self,
+                                  seq_group: SequenceGroup,
+                                  num_new_tokens: int = 1) -> bool:
+        # Simple heuristic: if every running sequence can get the blocks
+        # needed for num_new_tokens new tokens, we can append.
+        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
+        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
+        return num_seqs * cdiv(num_new_tokens,
+                               self.block_size) <= num_free_gpu_blocks
+
     def can_append_slot(self, seq_group: SequenceGroup) -> bool:
         # Simple heuristic: If there is at least one free block
         # for each sequence, we can append.
@@ -202,6 +223,29 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
                 self.gpu_allocator.free(last_block)
                 return last_block.block_number, new_block.block_number

+    def append_multiple_slots(
+            self,
+            seq: Sequence,
+            num_new_tokens: int = 1) -> Optional[Tuple[int, int]]:
+        """Allocate multiple physical slots for new tokens. This function is
+        used in speculative decoding.
+        """
+        block_table = self.block_tables[seq.seq_id]
+        seq_len = seq.data.get_len()
+        num_required_blocks = cdiv(seq_len + num_new_tokens, self.block_size)
+        while len(block_table) < num_required_blocks:
+            if (self.block_sliding_window
+                    and len(block_table) >= self.block_sliding_window):
+                # reuse a block
+                block_table.append(block_table[len(block_table) %
+                                               self.block_sliding_window])
+            else:
+                # The sequence has a new logical block.
+                # Allocate a new physical block.
+                block = self.gpu_allocator.allocate()
+                block_table.append(block)
+        return None
+
     def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
         # NOTE: fork does not allocate a new physical block.
         # Thus, it is always safe from OOM.
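The block-manager changes size allocations with cdiv (ceiling division) so that a sequence plus the speculated tokens always fits in whole blocks. A small worked example of that accounting (not part of the PR's code), using an illustrative block size:

# Worked example of the cdiv-based block accounting (illustrative numbers).
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


block_size = 16      # tokens per physical block (illustrative)
seq_len = 30         # tokens already in the sequence
num_new_tokens = 5   # e.g. speculate_length draft tokens appended at once

blocks_in_use = cdiv(seq_len, block_size)                     # 2
blocks_required = cdiv(seq_len + num_new_tokens, block_size)  # 3
# append_multiple_slots grows the block table by the difference (1 new block),
# while can_append_multiple_slots budgets cdiv(num_new_tokens, block_size)
# extra blocks per running sequence as a worst case.
print(blocks_in_use, blocks_required)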