Skip to content

Commit

Permalink
[Misc][Refactor] Introduce ExecuteModelData (vllm-project#4540)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored May 4, 2024
1 parent 344bf7c commit bc8ad68
Show file tree
Hide file tree
Showing 23 changed files with 359 additions and 515 deletions.
98 changes: 48 additions & 50 deletions tests/spec_decode/test_multi_step_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import torch

from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker

from .utils import (assert_logprobs_dict_allclose, create_batch,
create_execute_model_data,
create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache)

Expand Down Expand Up @@ -105,31 +104,32 @@ def test_same_output_for_single_step():

final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

multi_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

single_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
multi_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
actual_output, _ = multi_step_worker.sampler_output(
**multi_step_execute_model_data.to_dict(), sample_len=num_steps)
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group),
sample_len=num_steps)
assert len(actual_output) == num_steps
actual_output = actual_output[0]

single_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

zero_kv_cache(worker.cache_engine)
set_random_seed(seed)
expected_output = worker.execute_model(
**single_step_execute_model_data.to_dict(), )[0]
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=single_step_seq_group))[0]

actual_token_ids = [
output.samples[0].output_token for output in actual_output
Expand Down Expand Up @@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

continuations = [[1] for _ in prompts]
execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens), )
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)

# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
multi_step_output, _ = multi_step_worker.sampler_output(
**execute_model_data.to_dict(), sample_len=num_steps)
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps)

# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
Expand All @@ -215,16 +216,16 @@ def test_same_output_for_multi_step():

for _ in multi_step_output:

execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)

single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), ))
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))

# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
Expand Down Expand Up @@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
) for _ in range(k)
], True

execute_model_data, _, _ = create_batch(batch_size, k)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)

proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
proposal_len=k,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
Expand Down Expand Up @@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
max_proposal_len=prompt_len + k - 1,
)

execute_model_data, _, _ = create_batch(batch_size,
k,
prompt_len=prompt_len)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prompt_len=prompt_len)

proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
proposal_len=k,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
Expand Down Expand Up @@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
) for _ in range(k)
], True

execute_model_data, _, _ = create_batch(
seq_group_metadata_list, _, _ = create_batch(
batch_size,
k,
prompt_len=prompt_len,
prev_output_token_len=prev_output_token_len,
)

proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
proposal_len=k,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
Expand Down
64 changes: 29 additions & 35 deletions tests/spec_decode/test_ngram_worker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from .utils import (create_execute_model_data,
create_seq_group_metadata_from_prompts, create_worker)
from .utils import create_seq_group_metadata_from_prompts, create_worker


def test_ngram_algo_correctness_for_single_no_match():
Expand Down Expand Up @@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():

proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
Expand Down Expand Up @@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():

proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
Expand Down Expand Up @@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():

proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
Expand Down
Loading

0 comments on commit bc8ad68

Please sign in to comment.