
[WIP][1/N] Chunked Prefill #3106

Closed · wants to merge 38 commits

Changes from 1 commit (of 38 total)
06fe872
[1/n] Support efficient reshape caching.
rkooo567 Feb 28, 2024
9a0b6be
[2/n] support flash attention kernel
rkooo567 Feb 28, 2024
6947167
oss flash attention works
rkooo567 Feb 28, 2024
4769a26
in progress
rkooo567 Feb 28, 2024
963db44
flash attn enabled.
rkooo567 Feb 29, 2024
2b9c36b
ip
rkooo567 Feb 29, 2024
2c1bb6c
support every model
rkooo567 Feb 29, 2024
2bb5e62
Fixed broken tests.
rkooo567 Feb 29, 2024
78bb887
ip
rkooo567 Feb 29, 2024
74ac900
seems to work.
rkooo567 Mar 1, 2024
71bdada
.
rkooo567 Mar 1, 2024
d4c3b5d
ip?
rkooo567 Mar 1, 2024
baef7c6
block tables updated correctly
rkooo567 Mar 1, 2024
a12ec68
hopefully tests pass
rkooo567 Mar 1, 2024
0d8785f
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 3, 2024
08c8541
.
rkooo567 Mar 3, 2024
3bac9af
ip
rkooo567 Mar 3, 2024
31aa920
ip
rkooo567 Mar 4, 2024
2049b35
.
rkooo567 Mar 4, 2024
ef679d7
.
rkooo567 Mar 4, 2024
71bda97
.
rkooo567 Mar 4, 2024
4e00e7f
done?
rkooo567 Mar 4, 2024
7fd70f2
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 5, 2024
9177d54
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 6, 2024
c0384a4
Refactor 2d query to 1d query
rkooo567 Mar 6, 2024
6032edf
.,
rkooo567 Mar 6, 2024
c1ab0b0
done
rkooo567 Mar 6, 2024
f48dc72
Addressed code review.
rkooo567 Mar 7, 2024
769b2b4
working
rkooo567 Mar 7, 2024
4a20f4a
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f7347b8
working
rkooo567 Mar 7, 2024
d931725
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f91d73e
fix lora
rkooo567 Mar 8, 2024
f7d79da
fixed
rkooo567 Mar 8, 2024
851c018
Merge branch 'main' into 1dquery
rkooo567 Mar 8, 2024
406f1d4
fix
rkooo567 Mar 8, 2024
9442e8f
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 8, 2024
3da31eb
Merge branch '1dquery' into chunked-prefill-3
rkooo567 Mar 8, 2024
in progress
rkooo567 committed Feb 28, 2024
commit 4769a2636392d4ac1f25b2af758d008de4533f88
4 changes: 4 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -43,6 +43,10 @@ steps:
commands:
- pytest -v -s prefix_caching

- label: Chunked Prefill Test
commands:
- pytest -v -s chunked_prefill

- label: Samplers Test
command: pytest -v -s samplers --forked

40 changes: 37 additions & 3 deletions benchmarks/benchmark_latency.py
@@ -10,6 +10,13 @@

from vllm import LLM, SamplingParams

SAMPLE_PROMPTS = [
"The president of the United States is",
"Hello, my name is",
"The capital of France is",
"The future of AI is",
]


def main(args: argparse.Namespace):
print(args)
@@ -57,10 +64,24 @@ def run_to_completion(profile_dir: Optional[str] = None):
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
if args.use_sample:
batch = (
SAMPLE_PROMPTS *
(args.batch_size // len(SAMPLE_PROMPTS) + 1))[:args.batch_size]
outputs = llm.generate(prompts=batch,
sampling_params=sampling_params,
use_tqdm=False)
else:
outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
if args.verbose:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
latency = end_time - start_time
return latency

@@ -145,5 +166,18 @@ def run_to_completion(profile_dir: Optional[str] = None):
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument('--flash-style',
action='store_true',
help='enable flash attention')
parser.add_argument('--block-size',
type=int,
default=16,
help='block size of key/value cache')
parser.add_argument('--use-sample',
action='store_true',
help='use sample input instead of dummy input')
parser.add_argument('--verbose',
action='store_true',
help='print generated text')
args = parser.parse_args()
main(args)
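To make the new `--use-sample` path above easier to follow, here is a small standalone sketch of the prompt-tiling expression: `SAMPLE_PROMPTS` is repeated until it covers the requested batch size and then truncated. The `build_sample_batch` helper name is illustrative only and not part of this diff.

```python
SAMPLE_PROMPTS = [
    "The president of the United States is",
    "Hello, my name is",
    "The capital of France is",
    "The future of AI is",
]


def build_sample_batch(batch_size: int) -> list:
    # Repeat the sample prompts enough times to cover batch_size, then truncate.
    repeats = batch_size // len(SAMPLE_PROMPTS) + 1
    return (SAMPLE_PROMPTS * repeats)[:batch_size]


assert build_sample_batch(6) == SAMPLE_PROMPTS + SAMPLE_PROMPTS[:2]
assert len(build_sample_batch(1)) == 1
```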
9 changes: 1 addition & 8 deletions csrc/cache_kernels.cu
@@ -278,17 +278,12 @@ __global__ void reshape_and_cache_flash_kernel(
scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size]
scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int64_t* __restrict__ num_tokens, // [1]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size) {
const int64_t num_tokens_ = num_tokens[0];
const int64_t token_idx = blockIdx.x;
if (token_idx >= num_tokens_) {
return;
}
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
@@ -323,8 +318,7 @@ void reshape_and_cache_flash(
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens]
torch::Tensor& num_tokens) // [1]
torch::Tensor& slot_mapping) // [num_tokens]
{
int num_tokens_padded = key.size(0);
int num_heads = key.size(1);
@@ -347,7 +341,6 @@
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(),
num_tokens.data_ptr<int64_t>(),
key_stride,
value_stride,
num_heads,
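For reference, a hedged sketch of the Python-side call after this change: the kernel now takes the token count from `key.size(0)`, so no separate `num_tokens` tensor is passed. The `vllm._C` import path and the concrete shapes below are assumptions based on the comments in the kernel signature and on `tests/kernels/test_cache.py`, and the snippet needs a CUDA device with the vLLM extension built.

```python
import torch

from vllm._C import cache_ops  # assumed extension module path for the cache kernels

num_tokens, num_heads, head_size = 16, 8, 64
num_blocks, block_size = 4, 256

# Shapes follow the comments in the kernel signature above.
key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.half, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.half, device="cuda")
value_cache = torch.zeros_like(key_cache)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

# num_tokens is now inferred from key.size(0); only the tensors are passed.
cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping)
```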
1 change: 1 addition & 0 deletions requirements.txt
@@ -13,3 +13,4 @@ prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
flash-attn >= 2.5.0 # Required for chunked prefill.
82 changes: 82 additions & 0 deletions tests/chunked_prefill/test_correctness.py
@@ -0,0 +1,82 @@
import gc

from typing import List

import pytest
import torch

from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel

MODELS = [
"JackFram/llama-68m",
]

# SANG-TODO Read it from example.txt
TEST_PROMPTS = [
# pylint: disable=line-too-long
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
# Output differs between paged attention and flash attention.
# "Describe the basic components of a neural network and how it can be trained.",
"Write a short story about a robot that dreams for the first time.",
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
]


# TODO(sang): Add chunked prefill parameters.
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
vllm_runner,
model: str,
dtype: str,
max_tokens: int,
) -> None:
""" verify the flash attention has the same output
as page attention """
print("loading page attention models..")
pg_model = vllm_runner(model, dtype=dtype)
expected_outputs = []

print("generating tokens...")
expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens))
print("generating tokens finished")

del pg_model

destroy_model_parallel()
gc.collect()
torch.cuda.empty_cache()

flash_attn_model = vllm_runner(
model,
dtype=dtype,
enable_cuda_graph=False,
flash_style=True,
)
flash_attn_outputs_by_batch = []
for i in range(10):
prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)]
flash_attn_outputs_by_batch.append(
flash_attn_model.generate_greedy(prompts, max_tokens))

del flash_attn_model

destroy_model_parallel()
gc.collect()
torch.cuda.empty_cache()

for flash_attn_outputs in flash_attn_outputs_by_batch:
for i in range(len(flash_attn_outputs)):
fa_output_ids, fa_output_str = flash_attn_outputs[i]
vllm_output_ids, vllm_output_str = expected_outputs[
i % len(expected_outputs)]
assert fa_output_ids == vllm_output_ids, (
f"Test {i}:\nflash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}\n"
f"Test {i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}"
)
3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -165,6 +165,7 @@ def __init__(
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
flash_style: bool = False,
**kwargs,
) -> None:
self.model = LLM(
@@ -175,6 +176,8 @@ def __init__(
swap_space=0,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
flash_style=flash_style,
block_size=32,
**kwargs,
)

2 changes: 1 addition & 1 deletion tests/kernels/test_cache.py
@@ -294,7 +294,7 @@ def pad_key_value(key: torch.Tensor, value: torch.Tensor,
padded_key, padded_value = pad_key_value(key, value, padding)
# Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache_flash(padded_key, padded_value, key_cache,
value_cache, slot_mapping, num_tokens)
value_cache, slot_mapping)

# Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
1 change: 0 additions & 1 deletion tests/kernels/test_flash_attention.py
@@ -221,7 +221,6 @@ def test_flash_paged_attention(
scale,
padded_block_table,
padded_context_lens,
block_size,
alibi_slopes,
)

14 changes: 14 additions & 0 deletions vllm/config.py
@@ -60,6 +60,7 @@ class ModelConfig:
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
flash_style: Enable flash-style paged attention.
"""

def __init__(
@@ -79,6 +80,7 @@ def __init__(
quantization: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
flash_style: bool = False,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -93,6 +95,7 @@ def __init__(
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
self.flash_style = flash_style

if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
@@ -295,12 +298,14 @@ def __init__(
swap_space: int,
cache_dtype: str,
sliding_window: Optional[int] = None,
flash_style: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.flash_style = flash_style
self._verify_args()
self._verify_cache_dtype()

@@ -314,6 +319,15 @@ def _verify_args(self) -> None:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")

if self.flash_style:
logger.info("Flash attention enabled.")
if self.block_size < 256:
# Flash-style attention only supports block size >= 256 for now.
# https://github.com/Dao-AILab/flash-attention/pull/824 will fix it.
raise ValueError(
"Flash-style attention only supports block size >= 256. Got "
f"{self.block_size}.")

def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
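A hedged sketch of what the new validation implies for callers: with `flash_style=True`, the block size must be at least 256 until the upstream flash-attention fix lands. The constructor argument order follows this diff; the concrete values are illustrative only.

```python
from vllm.config import CacheConfig

cache_config = CacheConfig(
    block_size=256,               # anything smaller raises ValueError when flash_style=True
    gpu_memory_utilization=0.9,
    swap_space=4,                 # GiB of CPU swap space
    cache_dtype="auto",
    flash_style=True,
)
```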
9 changes: 7 additions & 2 deletions vllm/engine/arg_utils.py
@@ -45,6 +45,7 @@ class EngineArgs:
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'cuda'
flash_style: bool = False

def __post_init__(self):
if self.tokenizer is None:
@@ -271,6 +272,9 @@ def add_cli_args(
choices=["cuda"],
help=('Device type for vLLM execution. '
'Currently, only CUDA-compatible devices are supported.'))
parser.add_argument('--flash-style',
action='store_true',
help='Use flash-style attention.')
return parser

@classmethod
@@ -291,11 +295,12 @@ def create_engine_configs(
self.trust_remote_code, self.download_dir, self.load_format,
self.dtype, self.seed, self.revision, self.code_revision,
self.tokenizer_revision, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture)
self.enforce_eager, self.max_context_len_to_capture, self.flash_style)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
model_config.get_sliding_window())
model_config.get_sliding_window(),
self.flash_style)
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray,
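A hedged sketch of how the new flag is threaded from `EngineArgs` into both configs. The tuple unpacking assumes `create_engine_configs` returns the model and cache configs first, as shown in this diff; the model name is just an example.

```python
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="JackFram/llama-68m",
                         flash_style=True,
                         block_size=256)  # flash-style requires block_size >= 256
model_config, cache_config, *_ = engine_args.create_engine_configs()
assert model_config.flash_style and cache_config.flash_style
```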
46 changes: 46 additions & 0 deletions vllm/model_executor/input_metadata.py
@@ -13,6 +13,10 @@ class InputMetadata:
context_lens: the length of attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical block)
kv_cache_dtype: Data type to store kv cache.
num_prompt_tokens: The number of tokens in the prompts. This might
include padding.
num_generation_tokens: The number of tokens in the generation sequences.
This might include padding.
"""

def __init__(
@@ -27,6 +31,9 @@ def __init__(
block_tables: Optional[torch.Tensor],
use_cuda_graph: bool,
kv_cache_dtype: str,
# SANG-TODO
# num_prompt_tokens: int,
# num_generation_tokens: int,
) -> None:
self.is_prompt = is_prompt
self.prompt_lens = prompt_lens
@@ -43,6 +50,45 @@ def __init__(
# FIXME(woosuk): This is a hack.
self.attn_bias = None

# SANG-TODO
# # Prompt related metadata
# # This value might include padding if CudaGraph is enabled.
# self.num_prompts = len(prompt_lens)
# # This value is the source of truth.
# self.num_prompts_tensor = torch.cuda.IntTensor([self.num_prompts])
# # This value might include padding if CudaGraph is enabled.
# self.num_prompt_tokens = num_prompt_tokens
# self.prompt_lens_tensor = torch.cuda.IntTensor(self.prompt_lens)
# self.max_prompt_len = max(prompt_lens) if prompt_lens else 0

# # Cumulative prompt lengths for each prompt in the input
# # tensor.
# self.cum_prompt_query_lens = torch.zeros(
# self.num_prompts + 1,
# device=self.prompt_lens_tensor.device,
# dtype=torch.int32)
# # Cumulative context lengths.
# self.cum_prompt_context_lens = torch.zeros(
# self.num_prompts + 1,
# device=self.prompt_lens_tensor.device,
# dtype=torch.int32)

# torch.cumsum(self.prompt_lens_tensor,
# dim=0,
# dtype=self.cum_prompt_query_lens.dtype,
# out=self.cum_prompt_query_lens[1:])

# # TODO: this will be different once we support chunked prefills.
# self.cum_prompt_context_lens = self.cum_prompt_query_lens
# self.max_context_len = max(self.max_context_len, self.max_prompt_len)

# # Generation related metadata
# # This value might include padding if CudaGraph is enabled.
# self.num_generation_tokens = num_generation_tokens
# # This is the source of truth for the number of generation tokens.
# self.num_generation_tokens_tensor = torch.cuda.IntTensor(
# [num_generation_tokens])

def __repr__(self) -> str:
return ("InputMetadata("
f"is_prompt={self.is_prompt}, "
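For context, a hedged sketch of the cumulative-length bookkeeping the commented-out block above is preparing for: the prefix sums give each prompt's start offset in a flattened 1-D token tensor, which is the layout flash-attention's variable-length kernels consume. The example values are illustrative only.

```python
import torch

prompt_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# cu_prompt_lens[i] is the start offset of prompt i in the flattened token tensor.
cu_prompt_lens = torch.zeros(len(prompt_lens) + 1, dtype=torch.int32)
torch.cumsum(prompt_lens, dim=0, dtype=torch.int32, out=cu_prompt_lens[1:])

print(cu_prompt_lens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
```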