[2/N] Chunked prefill data update #3538

Merged

Changes from 103 commits

Commits (127)

06fe872
[1/n] Support efficient reshape caching.
rkooo567 Feb 28, 2024
9a0b6be
[2/n] support flash attention kernel
rkooo567 Feb 28, 2024
6947167
oss flash attention works
rkooo567 Feb 28, 2024
4769a26
in progress
rkooo567 Feb 28, 2024
963db44
flash attn enabled.
rkooo567 Feb 29, 2024
2b9c36b
ip
rkooo567 Feb 29, 2024
2c1bb6c
support every model
rkooo567 Feb 29, 2024
2bb5e62
Fixed broken tests.
rkooo567 Feb 29, 2024
4d6a05f
[2/n] scheduler changes
rkooo567 Feb 29, 2024
0831f84
[2/n] ip
rkooo567 Feb 29, 2024
f31371f
[2/n]ip
rkooo567 Feb 29, 2024
78bb887
ip
rkooo567 Feb 29, 2024
b9d93c5
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Feb 29, 2024
42dd362
[2/n] ip
rkooo567 Mar 1, 2024
74ac900
seems to work.
rkooo567 Mar 1, 2024
e3afc25
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 1, 2024
6141885
[2/n] ip
rkooo567 Mar 1, 2024
71bdada
.
rkooo567 Mar 1, 2024
d4c3b5d
ip?
rkooo567 Mar 1, 2024
baef7c6
block tables updated correctly
rkooo567 Mar 1, 2024
d503a22
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 1, 2024
a12ec68
hopefully tests pass
rkooo567 Mar 1, 2024
85760db
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 3, 2024
e40bc45
[2/n] update sequence data
rkooo567 Mar 3, 2024
d85670f
[2/n] add prefill range apis
rkooo567 Mar 3, 2024
0d8785f
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 3, 2024
08c8541
.
rkooo567 Mar 3, 2024
3bac9af
ip
rkooo567 Mar 3, 2024
0ca1284
add data.
rkooo567 Mar 3, 2024
2487bda
ip
rkooo567 Mar 3, 2024
81151e8
ip
rkooo567 Mar 3, 2024
31aa920
ip
rkooo567 Mar 4, 2024
2049b35
.
rkooo567 Mar 4, 2024
ef679d7
.
rkooo567 Mar 4, 2024
71bda97
.
rkooo567 Mar 4, 2024
4e00e7f
done?
rkooo567 Mar 4, 2024
c5f3a0d
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 4, 2024
7fd70f2
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 5, 2024
9bbb04e
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler-data-…
rkooo567 Mar 5, 2024
9177d54
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 6, 2024
5e47c1e
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler-data-…
rkooo567 Mar 6, 2024
c0384a4
Refactor 2d query to 1d query
rkooo567 Mar 6, 2024
6032edf
.,
rkooo567 Mar 6, 2024
c1ab0b0
done
rkooo567 Mar 6, 2024
f48dc72
Addressed code review.
rkooo567 Mar 7, 2024
769b2b4
working
rkooo567 Mar 7, 2024
4a20f4a
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f7347b8
working
rkooo567 Mar 7, 2024
d931725
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f91d73e
fix lora
rkooo567 Mar 8, 2024
f7d79da
fixed
rkooo567 Mar 8, 2024
851c018
Merge branch 'main' into 1dquery
rkooo567 Mar 8, 2024
406f1d4
fix
rkooo567 Mar 8, 2024
c66ec36
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
c067a4c
working.
rkooo567 Mar 11, 2024
e1f244a
clean up.
rkooo567 Mar 11, 2024
d09eaf5
.
rkooo567 Mar 11, 2024
4a8ab3c
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
a08e65e
Merge branch 'main' into 1dquery
rkooo567 Mar 11, 2024
d9532f8
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
93a7b90
.
rkooo567 Mar 12, 2024
b4b94c6
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
647d8cc
.
rkooo567 Mar 12, 2024
65ac6ce
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
b2f4b3e
ip
rkooo567 Mar 12, 2024
cc8419f
.
rkooo567 Mar 12, 2024
76e7ca8
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
d3d0336
Merge branch 'main' into 1dquery
rkooo567 Mar 15, 2024
11ec167
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 15, 2024
3cb8093
ip addressing comments.
rkooo567 Mar 16, 2024
5391129
Alibi slopes working now.
rkooo567 Mar 18, 2024
6b04443
Merge branch 'main' into 1dquery
rkooo567 Mar 18, 2024
fe344f6
add new fieflds
rkooo567 Mar 18, 2024
e619c4e
Flash attn works now
rkooo567 Mar 18, 2024
9c86aa3
Linting
rkooo567 Mar 18, 2024
5b4aa09
temporary
rkooo567 Mar 18, 2024
03dd155
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 18, 2024
4cced78
fix tests
rkooo567 Mar 18, 2024
cdb7a2c
Fixed
rkooo567 Mar 18, 2024
276be06
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 18, 2024
d87b651
Pass unit tests.
rkooo567 Mar 18, 2024
2c18896
experiment
rkooo567 Mar 18, 2024
b46f902
.
rkooo567 Mar 18, 2024
07b22f8
.
rkooo567 Mar 18, 2024
9bd7ea1
.
rkooo567 Mar 18, 2024
c55402f
trial
rkooo567 Mar 18, 2024
a13cf7e
remove --fork
rkooo567 Mar 18, 2024
c5c5581
Merge branch 'main' into 1dquery
rkooo567 Mar 18, 2024
ec91304
fixed
rkooo567 Mar 19, 2024
4977e53
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 19, 2024
4a54688
Merge branch 'main' into 1dquery
rkooo567 Mar 19, 2024
2e6e919
Addressed code review.
rkooo567 Mar 19, 2024
1f6f6b0
Merge branch 'main' into 1dquery
rkooo567 Mar 19, 2024
ac7828c
revert removing forked
rkooo567 Mar 19, 2024
3d7f1a1
done
rkooo567 Mar 19, 2024
bcdd74a
Merge branch 'main' into 1dquery
rkooo567 Mar 20, 2024
fa3ce4e
final code review.
rkooo567 Mar 20, 2024
a83b235
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 20, 2024
7205ef9
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 21, 2024
8bc0af5
.
rkooo567 Mar 21, 2024
97bcb6f
ip
rkooo567 Mar 21, 2024
df34350
working except tests.
rkooo567 Mar 21, 2024
e70e03d
.
rkooo567 Mar 21, 2024
f89f428
ip
rkooo567 Mar 21, 2024
bf02f8e
done
rkooo567 Mar 21, 2024
ad43095
done
rkooo567 Mar 21, 2024
16b6196
Addressed code review.
rkooo567 Mar 22, 2024
916abc8
merge conflict fixed
rkooo567 Mar 25, 2024
5002e61
update
rkooo567 Mar 25, 2024
80f51ea
test fix
rkooo567 Mar 25, 2024
3cc5e99
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 25, 2024
fa7ba35
lint
rkooo567 Mar 25, 2024
51cf7f2
fix broken tests.
rkooo567 Mar 25, 2024
cdee1c6
.
rkooo567 Mar 26, 2024
16e3a7d
done
rkooo567 Mar 26, 2024
e0d301c
remove num chunked prefill from seq group metadata
rkooo567 Mar 27, 2024
5e0f87e
change apis
rkooo567 Mar 27, 2024
6e72648
cleaned
rkooo567 Mar 27, 2024
4f869be
now working
rkooo567 Mar 27, 2024
4f63c57
update with new apis
rkooo567 Mar 27, 2024
5c3abf4
working!
rkooo567 Mar 27, 2024
66f3fcf
fixed
rkooo567 Mar 27, 2024
9c12d8e
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 27, 2024
9d4b65c
Addressed code review.
rkooo567 Mar 28, 2024
54a58b2
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 28, 2024
9bdb9dc
fix tests.
rkooo567 Mar 28, 2024
88126a9
fixed a bug
rkooo567 Mar 28, 2024
25 changes: 22 additions & 3 deletions benchmarks/benchmark_latency.py
@@ -26,6 +26,8 @@ def main(args: argparse.Namespace):
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
block_size=args.block_size,
max_chunked_prefill_len=args.max_chunked_prefill_len,
ray_workers_use_nsight=args.ray_workers_use_nsight,
)

@@ -58,10 +60,16 @@ def run_to_completion(profile_dir: Optional[str] = None):
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
if args.verbose:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: "
f"{generated_text!r}")
latency = end_time - start_time
return latency

@@ -146,6 +154,17 @@ def run_to_completion(profile_dir: Optional[str] = None):
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument('--block-size',
type=int,
default=16,
help='block size of key/value cache')
parser.add_argument('--use-sample',
action='store_true',
help='use sample input instead of dummy input')
parser.add_argument('--verbose',
action='store_true',
help='print generated text')
parser.add_argument('--max-chunked-prefill-len', type=int, default=-1)
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
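
For reference, a minimal sketch of how the new flags above map onto the Python API on this branch; the model name and prompt tokens are placeholders, not taken from the diff:

```python
# Sketch only: exercises the block_size / max_chunked_prefill_len engine
# arguments added in this PR. Model and prompt are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",        # placeholder model
    block_size=16,
    max_chunked_prefill_len=512,      # -1 (the default) disables chunked prefill
)
outputs = llm.generate(
    prompt_token_ids=[[1, 2, 3, 4]],  # placeholder prompt tokens
    sampling_params=SamplingParams(temperature=0, max_tokens=8),
    use_tqdm=False,
)
print(outputs[0].outputs[0].text)
```
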
5 changes: 4 additions & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
@@ -168,7 +168,10 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
type=int,
choices=[64, 80, 96, 112, 128, 256],
default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--block-size",
type=int,
choices=[16, 32, 256],
default=16)
parser.add_argument("--use-alibi", action="store_true")
parser.add_argument("--dtype",
type=str,
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -166,6 +166,9 @@ def __init__(
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
max_chunked_prefill_len: int = -1,
max_num_batched_tokens: int = 4096,
**kwargs,
) -> None:
self.model = LLM(
@@ -176,6 +179,9 @@ def __init__(
swap_space=0,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
max_chunked_prefill_len=max_chunked_prefill_len,
max_num_batched_tokens=max_num_batched_tokens,
**kwargs,
)

69 changes: 67 additions & 2 deletions tests/core/test_scheduler.py
@@ -55,8 +55,6 @@ def test_scheduler_schedule_simple():
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
scheduler = Scheduler(scheduler_config, cache_config, None)

# Add seq groups to scheduler.
running: List[SequenceGroup] = []
for i in range(num_seq_group):
_, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
@@ -131,6 +129,73 @@ def test_scheduler_schedule_preempt_abort():
assert scheduler.get_num_unfinished_seq_groups() == 1


def test_scheduler_schedule_chunked_prefill():
block_size = 4
num_seq_group = 2
max_model_len = 16
max_chunked_prefill_len = 2
scheduler_config = SchedulerConfig(
64,
num_seq_group,
max_model_len,
max_chunked_prefill_len=max_chunked_prefill_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
scheduler = Scheduler(scheduler_config, cache_config, None)

# Add seq groups to scheduler.
seq_groups: List[SequenceGroup] = []
for i in range(num_seq_group):
_, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
scheduler.add_seq_group(seq_group)
seq_groups.append(seq_group)

# Schedule chunk prefill. Only the first seq_group should be scheduled.
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(seq_groups[:1])
assert seq_groups[0].get_num_unprefilled() == 2
assert seq_groups[1].get_num_unprefilled() == 4
assert out.num_batched_tokens == 2
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 1
assert seq_group_meta[0].request_id == "0"
assert seq_group_meta[0].is_chunked_prefill
assert seq_group_meta[0].is_prompt

# Schedule chunk prefill. Still only the first seq_group should be
# scheduled.
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(seq_groups[:1])
assert seq_groups[0].get_num_unprefilled() == 0
assert seq_groups[1].get_num_unprefilled() == 4
assert out.num_batched_tokens == 2
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 1
assert seq_group_meta[0].request_id == "0"
assert not seq_group_meta[0].is_chunked_prefill
assert seq_group_meta[0].is_prompt

# Schedule chunk prefill. This time the second seq_group should be selected
# for chunk prefill, and the first seq_group should be selected for decoding
# (2 prefill tokens + 1 decode token = 3 batched tokens).
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(seq_groups)
assert seq_groups[0].get_num_unprefilled() == 0
assert seq_groups[1].get_num_unprefilled() == 2
assert out.num_batched_tokens == 3
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 2
assert seq_group_meta[0].request_id == "1"
assert seq_group_meta[0].is_chunked_prefill
assert seq_group_meta[0].is_prompt
assert seq_group_meta[1].request_id == "0"
assert not seq_group_meta[1].is_chunked_prefill
assert not seq_group_meta[1].is_prompt


def test_scheduler_max_seqs():
block_size = 4
num_seq_group = 4
3 changes: 3 additions & 0 deletions tests/samplers/test_sampler.py
@@ -56,6 +56,7 @@ def _do_sample(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
is_chunked_prefill=False,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
@@ -227,6 +228,7 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
is_chunked_prefill=False,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
@@ -316,6 +318,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
is_chunked_prefill=False,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1,
1 change: 1 addition & 0 deletions tests/spec_decode/utils.py
@@ -163,6 +163,7 @@ def create_seq_group_metadata_from_prompts(
SequenceGroupMetadata(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
is_chunked_prefill=False,
seq_data={
i:
SequenceData(
36 changes: 35 additions & 1 deletion tests/test_sequence.py
@@ -1,6 +1,16 @@
import pytest
from vllm.sequence import (SequenceData, Sequence, SequenceGroupOutput,
SamplerOutput, SequenceOutput)

from vllm.sequence import SequenceGroupOutput, SamplerOutput, SequenceOutput

@pytest.fixture(name="sequence")
def create_sequence(seq_len: int, block_size: int) -> Sequence:
return Sequence(
seq_id=0,
prompt="",
prompt_token_ids=list(range(seq_len)),
block_size=block_size,
)


@pytest.fixture
@@ -48,3 +58,27 @@ def test_sampler_output_eq(sample_outputs):
sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
assert sampler_output1 == sampler_output2
assert sampler_output1 != sampler_output3


def test_sequence_data_prefill():
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
assert seq_data.get_prefill_range() == (0, 0)
assert seq_data.get_num_unprefilled() == 4

# advance by 2
assert seq_data.advance_prefill_range(2) == 2
assert seq_data.get_num_unprefilled() == 2
assert seq_data.get_prefill_range() == (0, 2)

# advance range by 3 even though there are only 2 unprefilled tokens
assert seq_data.advance_prefill_range(3) == 2
assert seq_data.get_num_unprefilled() == 0
assert seq_data.get_prefill_range() == (2, 4)

# following advances should not change anything
assert seq_data.advance_prefill_range(2) == 0
assert seq_data.get_num_unprefilled() == 0
assert seq_data.get_prefill_range() == (4, 4)

# append tokens and reset, simulating recompute
seq_data.append_token_id(1, logprob=0.0)
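
The test above pins down the contract of the new prefill-range APIs. A minimal stand-in that reproduces that contract, for illustration only (the real SequenceData implementation lives in vllm/sequence.py and is not shown in this diff):

```python
# Illustration only: reproduces the behavior the test above checks for the
# new SequenceData prefill-range APIs; not the actual implementation.
from typing import List, Tuple


class PrefillRangeSketch:

    def __init__(self, prompt_token_ids: List[int]) -> None:
        self._num_prompt_tokens = len(prompt_token_ids)
        self._start = 0
        self._end = 0

    def get_prefill_range(self) -> Tuple[int, int]:
        return self._start, self._end

    def get_num_unprefilled(self) -> int:
        return self._num_prompt_tokens - self._end

    def advance_prefill_range(self, size: int) -> int:
        # Advance by at most the number of unprefilled tokens; (start, end)
        # always covers the chunk scheduled in the current step.
        size = min(size, self.get_num_unprefilled())
        self._start = self._end
        self._end += size
        return size


seq = PrefillRangeSketch([1, 2, 3, 4])
assert seq.get_prefill_range() == (0, 0)
assert seq.advance_prefill_range(2) == 2
assert seq.get_prefill_range() == (0, 2)
assert seq.advance_prefill_range(3) == 2   # clamped to the 2 remaining tokens
assert seq.get_prefill_range() == (2, 4)
assert seq.advance_prefill_range(2) == 0   # nothing left to prefill
assert seq.get_prefill_range() == (4, 4)
```
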
1 change: 1 addition & 0 deletions tests/worker/test_model_runner.py
@@ -27,6 +27,7 @@ def test_prepare_prompt():
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
is_chunked_prefill=False,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
10 changes: 9 additions & 1 deletion vllm/config.py
@@ -535,13 +535,18 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
max_chunked_prefill_len: The maximum number of tokens prefilled in a
single step. Longer prompts are split into multiple chunks. -1 means
no chunking (disabled). This feature is only supported for
flash-style attention.
"""

def __init__(
self,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
max_chunked_prefill_len: int = -1,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -551,10 +556,13 @@ def __init__(
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.chunked_prefill_enabled = max_chunked_prefill_len != -1
self.max_chunked_prefill_len = max_chunked_prefill_len
self._verify_args()

def _verify_args(self) -> None:
if self.max_num_batched_tokens < self.max_model_len:
if self.max_num_batched_tokens < self.max_model_len and \
not self.chunked_prefill_enabled:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). "
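
A short sketch of the resulting SchedulerConfig behavior; the numbers are illustrative, not taken from the diff:

```python
# Sketch only: illustrates the chunked_prefill_enabled flag and the relaxed
# max_num_batched_tokens check introduced above. Values are illustrative.
from vllm.config import SchedulerConfig

# Chunked prefill disabled (default): the batch budget must cover max_model_len.
config = SchedulerConfig(max_num_batched_tokens=4096,
                         max_num_seqs=256,
                         max_model_len=2048)
assert not config.chunked_prefill_enabled

# Chunked prefill enabled: long prompts are prefilled in chunks, so the batch
# budget may be smaller than max_model_len without _verify_args raising.
config = SchedulerConfig(max_num_batched_tokens=512,
                         max_num_seqs=256,
                         max_model_len=2048,
                         max_chunked_prefill_len=512)
assert config.chunked_prefill_enabled
```
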
2 changes: 2 additions & 0 deletions vllm/core/scheduler.py
@@ -239,6 +239,7 @@ def _schedule(self) -> SchedulerOutputs:
curr_loras.add(lora_int_id)
self.waiting.popleft()
self._allocate(seq_group)
seq_group.advance_prefill_range(num_prompt_tokens)
self.running.append(seq_group)
num_curr_seqs += num_new_seqs
scheduled.append(seq_group)
@@ -373,6 +374,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=scheduler_outputs.prompt_run,
is_chunked_prefill=False,
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
20 changes: 17 additions & 3 deletions vllm/engine/arg_utils.py
@@ -51,6 +51,7 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
max_chunked_prefill_len: int = -1

def __post_init__(self):
if self.tokenizer is None:
@@ -305,6 +306,17 @@ def add_cli_args(
default=EngineArgs.device,
choices=["auto", "cuda", "neuron"],
help='Device type for vLLM execution.')
parser.add_argument(
'--max-chunked-prefill-len',
type=int,
default=-1,
help='max number of prefill tokens allowed in chunked prefill'
', -1 means no limit')
parser.add_argument(
'--max-num-prompt-seqs',
type=int,
default=1024,
help='max number of prompt sequences allowed in prefill')
return parser

@classmethod
@@ -340,9 +352,11 @@ def create_engine_configs(
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
), self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len)
scheduler_config = SchedulerConfig(
self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
max_chunked_prefill_len=self.max_chunked_prefill_len)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
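
A sketch of how the new CLI flag reaches EngineArgs, assuming the pre-existing add_cli_args/from_cli_args helpers; the model name and flag value are illustrative:

```python
# Sketch only: parses the --max-chunked-prefill-len flag added above.
import argparse

from vllm.engine.arg_utils import EngineArgs

parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(
    ["--model", "facebook/opt-125m", "--max-chunked-prefill-len", "512"])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.max_chunked_prefill_len == 512
```
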
1 change: 0 additions & 1 deletion vllm/engine/llm_engine.py
@@ -615,7 +615,6 @@ def step(self) -> List[RequestOutput]:
>>> break
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

if not scheduler_outputs.is_empty():
output = self.model_executor.execute_model(
seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
2 changes: 2 additions & 0 deletions vllm/model_executor/input_metadata.py
@@ -22,6 +22,8 @@ class InputMetadata:
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# The number of chunked prefill sequences in the batch.
num_chunked_prefill: int
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
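
The slot_mapping comment above relies on the usual slot arithmetic; a small illustration under the assumption that slot = block_number * block_size + offset, reading "3rd slot" as offset 2 within the block:

```python
# Illustration only: slot_mapping arithmetic implied by the comment above,
# assuming slot = block_number * block_size + offset_in_block.
block_size = 16
# "3rd slot in block 2, 2nd slot in block 0, 1st slot in block 1"
placements = [(2, 2), (0, 1), (1, 0)]  # (block_number, offset_in_block)
slot_mapping = [block * block_size + offset for block, offset in placements]
assert slot_mapping == [34, 1, 16]
```
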
1 change: 0 additions & 1 deletion vllm/model_executor/layers/logits_processor.py
@@ -77,7 +77,6 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
rkooo567 (Collaborator, Author) commented: this is same as #3539

return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
