
Commit f78c169

gc-fu and xiangyuT authored
Support running batched inference in underlying changes (vllm-project#6)
* finish changing scheduler
* finish merge
* fix model
* Fix (vllm-project#5)
* fix problems
* fix
* delete unused params
* remove redundant comments

---------

Co-authored-by: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
1 parent 02b4cac commit f78c169

8 files changed (+415 −32 lines)

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+"""Try sending a mocked request to the underlying model execute stage."""
+
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.arg_utils import AsyncEngineArgs
+import pytest
+import asyncio
+
+# This is the model to load for workers
+MODEL_PATH = "/models/vicuna-7b/"
+
+
+"""
+1. Start an AsyncLLMEngine to make sure everything works before we start serving.
+"""
+
+@pytest.mark.asyncio
+async def test_model_execution():
+    # Let's build the engine args.
+    engine_args = AsyncEngineArgs(model=MODEL_PATH,
+                                  tokenizer=MODEL_PATH,
+                                  tokenizer_mode='auto',
+                                  trust_remote_code=False,
+                                  download_dir=None,
+                                  load_format='auto',
+                                  dtype='auto',
+                                  seed=0,
+                                  max_model_len=None,
+                                  worker_use_ray=False,
+                                  pipeline_parallel_size=1,
+                                  tensor_parallel_size=1,
+                                  block_size=16,
+                                  swap_space=16,
+                                  gpu_memory_utilization=0.9,
+                                  max_num_batched_tokens=None,
+                                  max_num_seqs=256,
+                                  disable_log_stats=False,
+                                  revision=None,
+                                  tokenizer_revision=None,
+                                  quantization=None,
+                                  engine_use_ray=False,
+                                  disable_log_requests=True,
+                                  max_log_len=None)
+    # Start the engine and let its background loop run for a few seconds.
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+    engine.start_background_loop()
+    await asyncio.sleep(5)
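
The test above only starts the engine's background loop and sleeps; it never actually submits a request. A natural follow-up (not part of this commit) would be to push a single request through the engine and await its outputs. The sketch below assumes that path works end-to-end in this fork; the prompt, request_id, and sampling values are illustrative, and MODEL_PATH is the same placeholder path used in the test.

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams

MODEL_PATH = "/models/vicuna-7b/"  # placeholder path, as in the test above


@pytest.mark.asyncio
async def test_single_request():
    # Hypothetical follow-up test: drive one request end-to-end.
    engine_args = AsyncEngineArgs(model=MODEL_PATH, tokenizer=MODEL_PATH)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
    final_output = None
    # generate() is an async generator that yields incremental RequestOutput
    # objects as tokens are produced.
    async for request_output in engine.generate("Hello, my name is",
                                                sampling_params,
                                                request_id="test-0"):
        final_output = request_output

    assert final_output is not None and final_output.finished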

vllm/core/scheduler.py

Lines changed: 293 additions & 2 deletions
@@ -61,14 +61,15 @@ def __init__(
         cache_config: CacheConfig,
     ) -> None:
         self.scheduler_config = scheduler_config
-        self.cache_config = cache_config
+        #self.cache_config = cache_config

         self.prompt_limit = min(self.scheduler_config.max_model_len,
                                 self.scheduler_config.max_num_batched_tokens)

         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
-        # # Create the block space manager.
+        # Create the block space manager.
+        # CO(gc): disable the block_manager
         # self.block_manager = BlockSpaceManager(
         #     block_size=self.cache_config.block_size,
         #     num_gpu_blocks=self.cache_config.num_gpu_blocks,
@@ -270,6 +271,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         for seq_group in scheduler_outputs.scheduled_seq_groups:
             seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
+
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
                 seq_data[seq_id] = seq.data
@@ -391,3 +393,292 @@ def _swap_out(
         blocks_to_swap_out.update(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             seq.status = SequenceStatus.SWAPPED
+
+
+
+
+class FixedWindowScheduler:
+
+    def __init__(
+        self,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
+    ) -> None:
+        self.scheduler_config = scheduler_config
+        #self.cache_config = cache_config
+
+        self.prompt_limit = min(self.scheduler_config.max_model_len,
+                                self.scheduler_config.max_num_batched_tokens)
+
+        # Instantiate the scheduling policy.
+        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
+
+        # Sequence groups in the WAITING state.
+        self.waiting: List[SequenceGroup] = []
+        # Sequence groups in the RUNNING state.
+        self.running: List[SequenceGroup] = []
+
+    def add_seq_group(self, seq_group: SequenceGroup) -> None:
+        # Add sequence groups to the waiting queue.
+        self.waiting.append(seq_group)
+
+    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        if isinstance(request_id, str):
+            request_id = (request_id, )
+        request_ids = set(request_id)
+        for state_queue in [self.waiting, self.running]:
+            # We need to reverse the list as we are removing elements
+            # from it as we iterate over it. If we don't do it,
+            # indices will get messed up and we will skip over elements.
+            for seq_group in reversed(state_queue):
+                if seq_group.request_id in request_ids:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.get_seqs():
+                        if seq.is_finished():
+                            continue
+                        seq.status = SequenceStatus.FINISHED_ABORTED
+                        self.free_seq(seq)
+                    request_ids.remove(seq_group.request_id)
+                    if not request_ids:
+                        return
+
+    def has_unfinished_seqs(self) -> bool:
+        return self.waiting or self.running
+
+    def get_num_unfinished_seq_groups(self) -> int:
+        return len(self.waiting) + len(self.running)
+
+    def _schedule(self) -> SchedulerOutputs:
+
+        # Fix the current time.
+        now = time.monotonic()
+
+        ignored_seq_groups: List[SequenceGroup] = []
+        scheduled: List[SequenceGroup] = []
+        # The total number of sequences on the fly, including the
+        # requests in the generation phase.
+        num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                            for seq_group in self.running)
+        num_batched_tokens = 0
+
+        if self.waiting:
+            # We restrict how many requests can be run using the three limits below.
+            # Co(gc): move waiting requests into the RUNNING state while the limits are not exceeded.
+            # Co(gc): prefill requests are prioritized over decoding-stage requests.
+            while self.waiting:
+                seq_group = self.waiting[0]
+
+                assert seq_group.num_seqs() == 1, (
+                    "Waiting sequence group should have only one prompt "
+                    "sequence.")
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                if num_prompt_tokens > self.prompt_limit:
+                    logger.warning(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        f" and exceeds limit of {self.prompt_limit}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    continue
+
+                # If the sequence group cannot be allocated, stop.
+                # if not self.block_manager.can_allocate(seq_group):
+                #     break
+
+                # If the number of batched tokens exceeds the limit, stop.
+                if (num_batched_tokens + num_prompt_tokens >
+                        self.scheduler_config.max_num_batched_tokens):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.get_max_num_running_seqs()
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.waiting.pop(0)
+                for seq in seq_group.get_seqs():
+                    seq.status = SequenceStatus.RUNNING
+                #self._allocate(seq_group)
+                self.running.append(seq_group)
+                num_batched_tokens += num_prompt_tokens
+                num_curr_seqs += num_new_seqs
+                scheduled.append(seq_group)
+
+            print("We have waited sequence_groups")
+
+            scheduler_outputs = SchedulerOutputs(
+                scheduled_seq_groups=scheduled,
+                prompt_run=True,
+                num_batched_tokens=num_batched_tokens,
+                blocks_to_swap_in={},
+                blocks_to_swap_out={},
+                blocks_to_copy={},
+                ignored_seq_groups=ignored_seq_groups,
+            )
+            return scheduler_outputs
+
+        # Now consider all the requests in the decoding stage.
+        self.running = self.policy.sort_by_priority(now, self.running)
+
+        # Each sequence in the generation phase only takes one token slot.
+        # Therefore, the number of batched tokens is equal to the number of
+        # sequences in the RUNNING state.
+        num_batched_tokens = sum(
+            seq_group.num_seqs(status=SequenceStatus.RUNNING)
+            for seq_group in self.running)
+
+        scheduler_outputs = SchedulerOutputs(
+            scheduled_seq_groups=self.running,
+            prompt_run=False,
+            num_batched_tokens=num_batched_tokens,
+            blocks_to_swap_in={},
+            blocks_to_swap_out={},
+            blocks_to_copy={},
+            ignored_seq_groups=[],
+        )
+        return scheduler_outputs
+
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
+        # Schedule sequence groups.
+        # This function call changes the internal states of the scheduler
+        # such as self.running, self.swapped, and self.waiting.
+        scheduler_outputs = self._schedule()
+
+        # Create input data structures.
+        seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        for seq_group in scheduler_outputs.scheduled_seq_groups:
+            seq_data: Dict[int, List[SequenceData]] = {}
+            block_tables: Dict[int, List[int]] = {}
+            print("Here we print the length of the seq_groups")
+            print(len(seq_group.get_seqs()))
+            print("The following sequences are scheduled")
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                seq_id = seq.seq_id
+                seq_data[seq_id] = seq.data
+                #block_tables[seq_id] = self.block_manager.get_block_table(seq)
+
+            seq_group_metadata = SequenceGroupMetadata(
+                request_id=seq_group.request_id,
+                is_prompt=scheduler_outputs.prompt_run,
+                seq_data=seq_data,
+                sampling_params=seq_group.sampling_params,
+                block_tables=block_tables,
+            )
+            print(seq_group_metadata.seq_data.keys())
+            seq_group_metadata_list.append(seq_group_metadata)
+        return seq_group_metadata_list, scheduler_outputs
+
+    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        self.block_manager.fork(parent_seq, child_seq)
+
+    def free_seq(self, seq: Sequence) -> None:
+        self.block_manager.free(seq)
+
+    def free_finished_seq_groups(self) -> None:
+        for seq_group in self.running:
+            if seq_group.is_finished():
+                print("A finished seq_group")
+                print(seq_group)
+        self.running = [
+            seq_group for seq_group in self.running
+            if not seq_group.is_finished()
+        ]
+
+    def _allocate(self, seq_group: SequenceGroup) -> None:
+        self.block_manager.allocate(seq_group)
+        for seq in seq_group.get_seqs():
+            seq.status = SequenceStatus.RUNNING
+
+    def _append_slot(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> None:
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+            ret = self.block_manager.append_slot(seq)
+            if ret is not None:
+                src_block, dst_block = ret
+                if src_block in blocks_to_copy:
+                    blocks_to_copy[src_block].append(dst_block)
+                else:
+                    blocks_to_copy[src_block] = [dst_block]
+
+    def _preempt(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+        preemption_mode: Optional[PreemptionMode] = None,
+    ) -> None:
+        # If preemption mode is not specified, we determine the mode as follows:
+        # We use recomputation by default since it incurs lower overhead than
+        # swapping. However, when the sequence group has multiple sequences
+        # (e.g., beam search), recomputation is not currently supported. In
+        # such a case, we use swapping instead.
+        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
+        # As swapped sequences are prioritized over waiting sequences,
+        # sequence groups with multiple sequences are implicitly prioritized
+        # over sequence groups with a single sequence.
+        # TODO(woosuk): Support recomputation for sequence groups with multiple
+        # sequences. This may require a more sophisticated CUDA kernel.
+        if preemption_mode is None:
+            if seq_group.get_max_num_running_seqs() == 1:
+                preemption_mode = PreemptionMode.RECOMPUTE
+            else:
+                preemption_mode = PreemptionMode.SWAP
+        if preemption_mode == PreemptionMode.RECOMPUTE:
+            self._preempt_by_recompute(seq_group)
+        elif preemption_mode == PreemptionMode.SWAP:
+            self._preempt_by_swap(seq_group, blocks_to_swap_out)
+        else:
+            assert False, "Invalid preemption mode."
+
+    def _preempt_by_recompute(
+        self,
+        seq_group: SequenceGroup,
+    ) -> None:
+        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+        assert len(seqs) == 1
+        for seq in seqs:
+            seq.status = SequenceStatus.WAITING
+            self.block_manager.free(seq)
+        # NOTE: For FCFS, we insert the preempted sequence group to the front
+        # of the waiting queue.
+        self.waiting.insert(0, seq_group)
+
+    def _preempt_by_swap(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+    ) -> None:
+        self._swap_out(seq_group, blocks_to_swap_out)
+        self.swapped.append(seq_group)
+
+    def _swap_in(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_in: Dict[int, int],
+    ) -> None:
+        mapping = self.block_manager.swap_in(seq_group)
+        blocks_to_swap_in.update(mapping)
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
+            seq.status = SequenceStatus.RUNNING
+
+    def _swap_out(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+    ) -> None:
+        if not self.block_manager.can_swap_out(seq_group):
+            # FIXME(woosuk): Abort the sequence group instead of aborting the
+            # entire engine.
+            raise RuntimeError(
+                "Aborted due to the lack of CPU swap space. Please increase "
+                "the swap space to avoid this error.")
+        mapping = self.block_manager.swap_out(seq_group)
+        blocks_to_swap_out.update(mapping)
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+            seq.status = SequenceStatus.SWAPPED
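
For orientation, here is a rough sketch of how an engine-side caller might drive the new FixedWindowScheduler, based only on the methods added in this hunk. The driver function and its name are illustrative, not part of the commit, and the model-execution step is elided.

from typing import List

from vllm.core.scheduler import FixedWindowScheduler
from vllm.sequence import SequenceGroup


def drive_scheduler(scheduler: FixedWindowScheduler,
                    seq_groups: List[SequenceGroup]) -> None:
    # Enqueue new requests; they land in the WAITING queue.
    for seq_group in seq_groups:
        scheduler.add_seq_group(seq_group)

    while scheduler.has_unfinished_seqs():
        # _schedule() first drains the WAITING queue as a prefill batch
        # (prompt_run=True); once nothing is waiting, it batches every
        # RUNNING group for one decoding step (prompt_run=False).
        seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

        # ... run the model on seq_group_metadata_list and append the sampled
        # tokens to the RUNNING sequences (elided) ...

        # Drop sequence groups that finished during this step.
        scheduler.free_finished_seq_groups()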

vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions
@@ -185,6 +185,7 @@ def create_engine_configs(
                                    self.dtype, self.seed, self.revision,
                                    self.tokenizer_revision, self.max_model_len,
                                    self.quantization)
+        # gc-TODO: disable cache_config later
         cache_config = CacheConfig(
             self.block_size, self.gpu_memory_utilization, self.swap_space,
             getattr(model_config.hf_config, 'sliding_window', None))
