Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API #4681

Merged
merged 49 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
c9fcb26
first checkpoint done
rkooo567 May 8, 2024
de61dbb
refactoring subquery
rkooo567 May 8, 2024
159ea2f
.
rkooo567 May 8, 2024
5833cbb
ip
rkooo567 May 8, 2024
7614ce0
working
rkooo567 May 8, 2024
6744eff
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 9, 2024
7de7f63
.
rkooo567 May 9, 2024
e8a4ea3
working with flash attn
rkooo567 May 9, 2024
64e8fd4
rocm and sdpa
rkooo567 May 9, 2024
ceec66d
working with flash infer
rkooo567 May 9, 2024
1851d59
add flash infer to pipeline
rkooo567 May 9, 2024
5cf1d3e
.
rkooo567 May 9, 2024
21a612a
working.
rkooo567 May 9, 2024
61dec37
fix spec decoding
rkooo567 May 9, 2024
6cad7bc
Fixed model runner test
rkooo567 May 9, 2024
e20a29e
fixed
rkooo567 May 9, 2024
94964ab
fix intel test
rkooo567 May 9, 2024
ff99251
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 10, 2024
1c77e2d
.
rkooo567 May 10, 2024
f929edd
done
rkooo567 May 10, 2024
74683a1
.
rkooo567 May 10, 2024
546735a
fix circular reference.
rkooo567 May 10, 2024
89e5df2
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 10, 2024
d7b2743
working
rkooo567 May 10, 2024
0ed4160
Merge branch 'circular-dep' into model-runner-refactoring-coelsce
rkooo567 May 10, 2024
f5af730
fixed spec decoding
rkooo567 May 10, 2024
7e39882
working
rkooo567 May 10, 2024
e02bc5d
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 13, 2024
dd48c00
fix embedding meta
rkooo567 May 13, 2024
bba70f1
ip
rkooo567 May 13, 2024
f76b9ea
improve assert
rkooo567 May 13, 2024
cf1dbbb
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 13, 2024
a281d97
done
rkooo567 May 13, 2024
f6afb05
lint
rkooo567 May 13, 2024
35f64ac
done
rkooo567 May 14, 2024
cc5df57
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 14, 2024
ccf937c
.
rkooo567 May 14, 2024
0951715
works except spec decoding
rkooo567 May 14, 2024
2b2423d
.
rkooo567 May 14, 2024
f1c12f3
.,
rkooo567 May 14, 2024
8a01746
.
rkooo567 May 14, 2024
bf7959a
.
rkooo567 May 14, 2024
b42b43d
.
rkooo567 May 14, 2024
e9a973e
ip
rkooo567 May 14, 2024
4e733e2
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 14, 2024
237e939
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 14, 2024
35e98a0
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 15, 2024
426d99a
.
rkooo567 May 15, 2024
1271556
done
rkooo567 May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
first checkpoint done
  • Loading branch information
rkooo567 committed May 8, 2024
commit c9fcb26d37202fadd9c4a965dcb9ca04520d62b3
5 changes: 3 additions & 2 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

MODELS = [
"facebook/opt-125m",
# "facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
Expand All @@ -16,7 +16,8 @@
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
# @pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("enforce_eager", [False])
def test_models(
hf_runner,
vllm_runner,
Expand Down
8 changes: 5 additions & 3 deletions tests/basic_correctness/test_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
# "meta-llama/Llama-2-7b-hf",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
# @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
# @pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("enforce_eager", [False])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
Expand Down
102 changes: 80 additions & 22 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size

def _prepare_prompt(
def _prepare_hybrid_batch(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PreparePromptMetadata:
Expand All @@ -237,14 +237,17 @@ def _prepare_prompt(
seq_lens: List[int] = []
context_lens: List[int] = []
query_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
decode_only = True

if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()

for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
if seq_group_metadata.is_prompt:
decode_only = False

seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
Expand Down Expand Up @@ -273,20 +276,21 @@ def _prepare_prompt(
# Prefix is not supported with sliding_window
context_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
block_tables.append(computed_block_nums)
# elif self.scheduler_config.chunked_prefill_enabled:
else:
if seq_group_metadata.block_tables is not None:
# Prefill has chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert context_len == 0
block_tables.append([])
# else:
# prefix_block_tables.append([])
# # Right now, prefill start is always 0. However, this
# # assumption can be changed once chunked prefill is introduced.
# assert context_len == 0

# actual prompt lens
context_lens.append(context_len)
Expand Down Expand Up @@ -342,8 +346,59 @@ def _prepare_prompt(
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

# max_query_len = max(query_lens)
# max_seq_len = max(seq_lens)

# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
batch_size = len(input_tokens)
max_query_len = max(query_lens)
max_seq_len = max(seq_lens)
use_captured_graph = (decode_only
and not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
for _ in range(graph_batch_size - batch_size):
input_tokens.append(0)
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
batch_size = graph_batch_size

seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)

if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert seq_lens_tensor.shape[0] == len(input_tokens)
assert seq_lens_tensor.shape[0] == len(input_positions)
assert seq_lens_tensor.shape[0] == len(slot_mapping)

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size]
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=self.device)
else:
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
assert max_query_len > 0

context_lens_tensor = torch.tensor(context_lens,
Expand All @@ -360,14 +415,14 @@ def _prepare_prompt(
multi_modal_input = None

# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = make_tensor_with_pad(
prefix_block_tables,
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
# max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
# block_tables = make_tensor_with_pad(
# prefix_block_tables,
# max_len=max_prompt_block_table_len,
# pad=0,
# dtype=torch.int,
# device=self.device,
# )

# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
Expand Down Expand Up @@ -637,16 +692,19 @@ def prepare_input_tensors(
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
) = self._prepare_hybrid_batch(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
_,
_,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
_,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
) = self._prepare_hybrid_batch(decode_reqs)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)
Expand Down
Loading