[WIP] Speculative decoding using a draft model #2188

Closed · wants to merge 1 commit
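
For background on the technique named in the title: speculative decoding runs a small draft model to propose several tokens, then verifies them with the target model in a single forward pass and keeps the longest accepted prefix. The sketch below is a minimal, framework-agnostic illustration of that draft-then-verify loop with greedy acceptance only; it is not the scheduler/worker implementation in this PR, and draft_model/target_model are assumed stand-ins for callables that return next-token logits.

import torch


def speculative_decode_greedy(target_model, draft_model, prompt_ids,
                              k=4, max_new_tokens=64):
    """Draft-then-verify with greedy acceptance (illustrative only)."""
    tokens = list(prompt_ids)
    generated = 0
    while generated < max_new_tokens:
        # 1) The cheap draft model proposes k tokens autoregressively.
        proposal, draft_ctx = [], list(tokens)
        for _ in range(k):
            logits = draft_model(torch.tensor([draft_ctx]))[0, -1]
            tok = int(torch.argmax(logits))
            proposal.append(tok)
            draft_ctx.append(tok)

        # 2) The target model scores context + proposal in one forward pass.
        target_logits = target_model(torch.tensor([tokens + proposal]))[0]

        # 3) Accept proposed tokens while they match the target's greedy choice.
        #    Logits at position p predict the token at position p + 1.
        n_accepted = 0
        for i, tok in enumerate(proposal):
            if int(torch.argmax(target_logits[len(tokens) - 1 + i])) == tok:
                n_accepted += 1
            else:
                break
        tokens.extend(proposal[:n_accepted])

        # 4) Always emit one token from the target itself (the correction or bonus).
        tokens.append(int(torch.argmax(target_logits[len(tokens) - 1])))
        generated += n_accepted + 1
    return tokens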
196 changes: 148 additions & 48 deletions benchmarks/benchmark_latency.py
@@ -1,34 +1,72 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time
from pathlib import Path
from typing import Optional

import os
import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.anyscale.lora.utils import LoRARequest

SAMPLE_PROMPTS = [
"The president of the United States is",
"Hello, my name is",
"The capital of France is",
"The future of AI is",
]


def add_lora(llm, batch_size):
LORA_FILE1 = "/mnt/local_storage/lora/"
for i in range(batch_size):
lora_request = LoRARequest(lora_id=f"lora_{i + 1}",
lora_int_id=i + 1,
lora_local_path=LORA_FILE1)
assert llm.llm_engine.add_lora(lora_request)


def main(args: argparse.Namespace):
print(args)

# Process all the requests in a single batch if possible.
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.batch_size,
max_num_batched_tokens=40960,
trust_remote_code=args.trust_remote_code,
load_format="dummy" if args.use_dummy_weights else "auto",
enable_lora=args.enable_lora,
enable_cuda_graph=args.enable_cuda_graph,
cuda_graph_cache_size=args.cuda_graph_cache_size,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
flash_style=args.flash_style,
max_chunked_prefill_len=args.max_chunked_prefill_len,
max_num_prompt_seqs=args.max_num_prompt_seqs,
block_size=32 if args.flash_style else args.block_size,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
speculative_model_uses_tp_1=args.speculative_model_uses_tp_1,
ray_workers_use_nsight=args.run_with_nsight,
disable_shared_memory=args.disable_shared_memory,
worker_use_ray=args.worker_use_ray,
disable_log_stats=not args.log_engine_stats,
)

if args.enable_lora:
lora_request = add_lora(llm, args.batch_size)
else:
lora_request = None

sampling_params = SamplingParams(
n=args.n,
-temperature=0.0 if args.use_beam_search else 1.0,
+temperature=0 if args.use_sample else 1.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True,
@@ -37,44 +75,75 @@ def main(args: argparse.Namespace):
print(sampling_params)
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

-def run_to_completion(profile_dir: Optional[str] = None):
-if profile_dir:
-with torch.profiler.profile(
-activities=[
-torch.profiler.ProfilerActivity.CPU,
-torch.profiler.ProfilerActivity.CUDA,
-],
-on_trace_ready=torch.profiler.tensorboard_trace_handler(
-str(profile_dir))) as p:
-llm.generate(prompt_token_ids=dummy_prompt_token_ids,
-sampling_params=sampling_params,
-use_tqdm=False)
-print(p.key_averages())
+def run_to_completion():
+start_time = time.perf_counter()
+
+if args.use_sample:
+batch = (
+SAMPLE_PROMPTS *
+(args.batch_size // len(SAMPLE_PROMPTS) + 1))[:args.batch_size]
+outputs = llm.generate(prompts=batch,
+sampling_params=sampling_params,
+use_tqdm=False,
+lora_request=lora_request)
 else:
-start_time = time.perf_counter()
-llm.generate(prompt_token_ids=dummy_prompt_token_ids,
-sampling_params=sampling_params,
-use_tqdm=False)
-end_time = time.perf_counter()
-latency = end_time - start_time
-return latency
+outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+sampling_params=sampling_params,
+use_tqdm=False,
+lora_request=lora_request)

end_time = time.perf_counter()

if args.verbose:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

latency = end_time - start_time
return latency

if args.profile and args.enable_cuda_graph:
# Workaround to enable profiling cuda graphs.
# https://github.com/pytorch/pytorch/issues/75504#issuecomment-1467065935
llm.llm_engine.start_profile(
profile_ray_workers=args.profile_ray_workers)
llm.llm_engine.stop_profile(
profile_ray_workers=args.profile_ray_workers)

print("Warming up...")
-run_to_completion(profile_dir=None)
+run_to_completion()

if args.profile:
-profile_dir = args.profile_result_dir
-if not profile_dir:
-profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
-print(f"Profiling (results will be saved to '{profile_dir}')...")
-run_to_completion(profile_dir=args.profile_result_dir)
-return
+model_name = args.model.replace("/", "-")
+profile_logdir_name = os.path.join(
+args.profile_logdir,
+f"{model_name}_tp-{args.tensor_parallel_size}_input-len{args.input_len}_output-len{args.output_len}_batch-size{args.batch_size}"
+.lstrip("-"))
+llm.llm_engine.start_profile(
+profile_ray_workers=args.profile_ray_workers,
+activities=[
+torch.profiler.ProfilerActivity.CPU,
+torch.profiler.ProfilerActivity.CUDA
+],
+on_trace_ready=torch.profiler.tensorboard_trace_handler(
+profile_logdir_name),
+with_stack=True)

# Benchmark.
latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
-latencies.append(run_to_completion(profile_dir=None))
+latencies.append(run_to_completion())
print(f'Avg latency: {np.mean(latencies)} seconds')
print(
f'Avg ITL: {1000*np.mean(latencies)/args.output_len:.02f} milliseconds'
)
print(f'Peak Cuda memory: {torch.cuda.max_memory_allocated()}')

if args.profile:
llm.llm_engine.stop_profile(
profile_ray_workers=args.profile_ray_workers, )


if __name__ == '__main__':
@@ -85,12 +154,12 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
-choices=['awq', 'gptq', 'squeezellm', None],
+choices=['awq', 'squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
-parser.add_argument('--batch-size', type=int, default=8)
+parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--n',
type=int,
default=1,
@@ -103,6 +172,24 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument('--enable-lora',
action='store_true',
help='enable LoRA adapters')
parser.add_argument('--enable-cuda-graph',
action='store_true',
help='enable cuda graph for decoding')
parser.add_argument('--cuda-graph-cache-size',
type=int,
default=200,
help='number of cuda graphs to cache')
parser.add_argument('--use-dummy-weights',
action='store_true',
help='load dummy weights instead of real checkpoints')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--speculative-model-uses-tp-1',
action='store_true',
help='run the speculative (draft) model with tensor parallel size 1')
parser.add_argument(
'--dtype',
type=str,
@@ -112,20 +199,33 @@ def run_to_completion(profile_dir: Optional[str] = None):
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
-parser.add_argument('--enforce-eager',
+parser.add_argument('--run-with-nsight', action='store_true')
+parser.add_argument('--profile', action='store_true')
+parser.add_argument('--profile-logdir', type=str, default=None)
+parser.add_argument('--profile-ray-workers', action='store_true')
+parser.add_argument('--max-chunked-prefill-len', type=int, default=-1)
+parser.add_argument('--max-num-prompt-seqs', type=int, default=1000)
+parser.add_argument('--flash-style',
 action='store_true',
-help='enforce eager mode and disable CUDA graph')
-parser.add_argument(
-'--profile',
-action='store_true',
-help='profile the generation process of a single batch')
-parser.add_argument(
-'--profile-result-dir',
-type=str,
-default=None,
-help=(
-'path to save the pytorch profiler output. Can be visualized '
-'with ui.perfetto.dev or Tensorboard.'
-))
+help='enable flash attention')
parser.add_argument('--block-size',
type=int,
default=16,
help='block size of key/value cache')
parser.add_argument('--use-sample',
action='store_true',
help='use sample input instead of dummy input')
parser.add_argument('--disable-shared-memory',
action='store_true',
help='disable shared memory')
parser.add_argument('--verbose',
action='store_true',
help='print generated text')
parser.add_argument('--log-engine-stats',
action='store_true',
help='log engine stats')
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray worker')
args = parser.parse_args()
main(args)
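
For context on how the new engine arguments above are meant to be used, here is a hedged sketch of driving them directly through the LLM entry point rather than through this benchmark script. Only the keyword names visible in this diff (speculative_model, num_speculative_tokens, speculative_model_uses_tp_1, prompts, sampling_params) are taken from the PR; the model names and values are placeholders.

from vllm import LLM, SamplingParams

# Placeholder models; the keyword arguments mirror the ones wired through
# benchmark_latency.py in this PR.
llm = LLM(
    model="facebook/opt-6.7b",              # target model (placeholder)
    speculative_model="facebook/opt-125m",  # draft model (placeholder)
    num_speculative_tokens=5,               # draft tokens proposed per step
    speculative_model_uses_tp_1=True,       # keep the draft model at TP=1
)

outputs = llm.generate(prompts=["The future of AI is"],
                       sampling_params=SamplingParams(temperature=0,
                                                      max_tokens=64))
print(outputs[0].outputs[0].text)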
Empty file added tests/anyscale/__init__.py
67 changes: 67 additions & 0 deletions tests/anyscale/utils.py
@@ -0,0 +1,67 @@

import gc
import json
import logging
import os

import boto3
import ray
import torch

from vllm.model_executor.parallel_utils.parallel_state import \
destroy_model_parallel

ENV_TOKEN_OVERRIDES = os.getenv("AVIARY_ENV_AWS_SECRET_NAME",
"aviary/env_overrides")

logger = logging.getLogger(__name__)


def cleanup():
# Revert to torch default after vllm modifications
torch.backends.cuda.matmul.allow_tf32 = False
torch.set_default_dtype(torch.float32)
destroy_model_parallel()
gc.collect()
torch.cuda.empty_cache()
ray.shutdown()


# Copied from aviary
class SecretManager:

def __init__(self, secret_name: str = ENV_TOKEN_OVERRIDES):
self.secret_overrides = self.get_all_secrets(secret_name)

def get_all_secrets(self, secret_name: str):
try:
aws_region_name = os.getenv("AWS_REGION", "us-west-2")

# Create a Secrets Manager client
session = boto3.session.Session()
client = session.client(service_name="secretsmanager",
region_name=aws_region_name)
get_secret_value_response = client.get_secret_value(
SecretId=secret_name)

# Decrypts secret using the associated KMS key.
secret = get_secret_value_response["SecretString"]

secret_dict = json.loads(secret)
return secret_dict
except Exception as e:
print(
f"Unable to load env override secrets from {secret_name}. Using default secrets from env. {e}"
)
return {}

def override_secret(self, env_var_name: str, set_in_env=True):
# First read from env var, then from aws secrets
secret = os.getenv(env_var_name,
self.secret_overrides.get(env_var_name))
if secret is None:
print(f"Secret {env_var_name} was not found.")
elif set_in_env:
os.environ[env_var_name] = secret
print(f"Secret {env_var_name} was set in the env.")
return secret
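
A brief usage sketch of the helpers above, as they might appear in a test (assumed call pattern; the environment variable name is illustrative):

from tests.anyscale.utils import SecretManager, cleanup

secrets = SecretManager()  # pulls overrides from AWS Secrets Manager, or {} on failure
secrets.override_secret("HUGGING_FACE_HUB_TOKEN")  # exported to os.environ if found

# ... construct an LLM engine and run assertions ...

cleanup()  # reset dtype/TF32 defaults, tear down model parallelism, free CUDA memory, stop Ray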
3 changes: 2 additions & 1 deletion tests/async_engine/api_server_async_engine.py
@@ -14,13 +14,14 @@

class AsyncLLMEngineWithStats(AsyncLLMEngine):

# pylint: disable=redefined-outer-name
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_aborts = 0

 async def abort(self, request_id: str) -> None:
-await super().abort(request_id)
 self._num_aborts += 1
+await super().abort(request_id)

def testing_stats(self) -> Dict[str, Any]:
return {"num_aborted_requests": self._num_aborts}