@@ -61,14 +61,15 @@ def __init__(
         cache_config: CacheConfig,
     ) -> None:
         self.scheduler_config = scheduler_config
-        self.cache_config = cache_config
+        # self.cache_config = cache_config
 
         self.prompt_limit = min(self.scheduler_config.max_model_len,
                                 self.scheduler_config.max_num_batched_tokens)
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
-        # # Create the block space manager.
+        # Create the block space manager.
+        # CO(gc): disable the block_manager
         # self.block_manager = BlockSpaceManager(
         #     block_size=self.cache_config.block_size,
         #     num_gpu_blocks=self.cache_config.num_gpu_blocks,
@@ -270,6 +271,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         for seq_group in scheduler_outputs.scheduled_seq_groups:
             seq_data: Dict[int, SequenceData] = {}
             block_tables: Dict[int, List[int]] = {}
+
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
                 seq_data[seq_id] = seq.data
@@ -391,3 +393,292 @@ def _swap_out(
         blocks_to_swap_out.update(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             seq.status = SequenceStatus.SWAPPED
+
+
+class FixedWindowScheduler:
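+    """A Scheduler variant that runs without a BlockSpaceManager.
+
+    A minimal usage sketch (hypothetical driver loop; in vLLM the engine,
+    e.g. LLMEngine, normally drives these calls):
+
+        scheduler = FixedWindowScheduler(scheduler_config, cache_config)
+        scheduler.add_seq_group(seq_group)
+        while scheduler.has_unfinished_seqs():
+            metadata_list, outputs = scheduler.schedule()
+            # run the model on metadata_list, then:
+            scheduler.free_finished_seq_groups()
+    """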
+
+    def __init__(
+        self,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
+    ) -> None:
+        self.scheduler_config = scheduler_config
+        # self.cache_config = cache_config
+
+        self.prompt_limit = min(self.scheduler_config.max_model_len,
+                                self.scheduler_config.max_num_batched_tokens)
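+        # Prompts longer than this limit are rejected and marked
+        # FINISHED_IGNORED in _schedule().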
+
+        # Instantiate the scheduling policy.
+        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
+
+        # Sequence groups in the WAITING state.
+        self.waiting: List[SequenceGroup] = []
+        # Sequence groups in the RUNNING state.
+        self.running: List[SequenceGroup] = []
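+        # NOTE: unlike Scheduler, this class keeps no SWAPPED queue; with the
+        # block manager disabled there is no swapping or preemption path.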
+
+    def add_seq_group(self, seq_group: SequenceGroup) -> None:
+        # Add sequence groups to the waiting queue.
+        self.waiting.append(seq_group)
+
+    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        if isinstance(request_id, str):
+            request_id = (request_id, )
+        request_ids = set(request_id)
+        for state_queue in [self.waiting, self.running]:
+            # We need to reverse the list as we are removing elements
+            # from it as we iterate over it. If we don't do it,
+            # indices will get messed up and we will skip over elements.
+            for seq_group in reversed(state_queue):
+                if seq_group.request_id in request_ids:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.get_seqs():
+                        if seq.is_finished():
+                            continue
+                        seq.status = SequenceStatus.FINISHED_ABORTED
+                        self.free_seq(seq)
+                    request_ids.remove(seq_group.request_id)
+                    if not request_ids:
+                        return
+
+    def has_unfinished_seqs(self) -> bool:
+        return self.waiting or self.running
+
+    def get_num_unfinished_seq_groups(self) -> int:
+        return len(self.waiting) + len(self.running)
+
+    def _schedule(self) -> SchedulerOutputs:
+        # Fix the current time.
+        now = time.monotonic()
+
+        ignored_seq_groups: List[SequenceGroup] = []
+        scheduled: List[SequenceGroup] = []
+        # The total number of sequences on the fly, including the
+        # requests in the generation phase.
+        num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                            for seq_group in self.running)
+        num_batched_tokens = 0
+
+        if self.waiting:
+            # We restrict how many requests can run at once via three limits:
+            # the prompt length limit, the batched-token budget, and the
+            # maximum number of sequences.
+            # Co(gc): if there are waiting requests, try to move them into the
+            # RUNNING state as long as these limits are not exceeded.
+            # Co(gc): prefill requests are prioritized over decode-stage requests.
+            while self.waiting:
+                seq_group = self.waiting[0]
+
+                assert seq_group.num_seqs() == 1, (
+                    "Waiting sequence group should have only one prompt "
+                    "sequence.")
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                if num_prompt_tokens > self.prompt_limit:
+                    logger.warning(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        f" and exceeds limit of {self.prompt_limit}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    continue
+
+                # If the sequence group cannot be allocated, stop.
+                # if not self.block_manager.can_allocate(seq_group):
+                #     break
+
+                # If the number of batched tokens exceeds the limit, stop.
+                if (num_batched_tokens + num_prompt_tokens >
+                        self.scheduler_config.max_num_batched_tokens):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.get_max_num_running_seqs()
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
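+
+                # All limits passed: admit the request into the RUNNING state.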
+                seq_group = self.waiting.pop(0)
+                for seq in seq_group.get_seqs():
+                    seq.status = SequenceStatus.RUNNING
+                # self._allocate(seq_group)
+                self.running.append(seq_group)
+                num_batched_tokens += num_prompt_tokens
+                num_curr_seqs += num_new_seqs
+                scheduled.append(seq_group)
+
+            print("Scheduled sequence groups from the waiting queue")
+
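+            # With the block manager disabled there are no block tables, so
+            # the swap and copy mappings below stay empty.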
+            scheduler_outputs = SchedulerOutputs(
+                scheduled_seq_groups=scheduled,
+                prompt_run=True,
+                num_batched_tokens=num_batched_tokens,
+                blocks_to_swap_in={},
+                blocks_to_swap_out={},
+                blocks_to_copy={},
+                ignored_seq_groups=ignored_seq_groups,
+            )
+            return scheduler_outputs
+
+        # Now consider all the requests in the decoding stage.
+        self.running = self.policy.sort_by_priority(now, self.running)
+
+        # Each sequence in the generation phase only takes one token slot.
+        # Therefore, the number of batched tokens is equal to the number of
+        # sequences in the RUNNING state.
+        num_batched_tokens = sum(
+            seq_group.num_seqs(status=SequenceStatus.RUNNING)
+            for seq_group in self.running)
+
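+        # As in the prefill branch, the block mappings below stay empty
+        # because the block manager is disabled.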
+        scheduler_outputs = SchedulerOutputs(
+            scheduled_seq_groups=self.running,
+            prompt_run=False,
+            num_batched_tokens=num_batched_tokens,
+            blocks_to_swap_in={},
+            blocks_to_swap_out={},
+            blocks_to_copy={},
+            ignored_seq_groups=[],
+        )
+        return scheduler_outputs
+
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
+        # Schedule sequence groups.
+        # This function call changes the internal states of the scheduler,
+        # namely self.running and self.waiting.
+        scheduler_outputs = self._schedule()
550+
551+ # Create input data structures.
552+ seq_group_metadata_list : List [SequenceGroupMetadata ] = []
553+ for seq_group in scheduler_outputs .scheduled_seq_groups :
554+ seq_data : Dict [int , List [SequenceData ]] = {}
555+ block_tables : Dict [int , List [int ]] = {}
556+ print ("Here we print the length of the seq_groups" )
557+ print (len (seq_group .get_seqs ()))
558+ print ("The following sequences are scheduled" )
559+ for seq in seq_group .get_seqs (status = SequenceStatus .RUNNING ):
560+ seq_id = seq .seq_id
561+ seq_data [seq_id ] = seq .data
562+ #block_tables[seq_id] = self.block_manager.get_block_table(seq)
563+
+            seq_group_metadata = SequenceGroupMetadata(
+                request_id=seq_group.request_id,
+                is_prompt=scheduler_outputs.prompt_run,
+                seq_data=seq_data,
+                sampling_params=seq_group.sampling_params,
+                block_tables=block_tables,
+            )
+            print(seq_group_metadata.seq_data.keys())
+            seq_group_metadata_list.append(seq_group_metadata)
+        return seq_group_metadata_list, scheduler_outputs
+
+    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        # No-op: the block manager is disabled in this scheduler.
+        # self.block_manager.fork(parent_seq, child_seq)
+        pass
+
+    def free_seq(self, seq: Sequence) -> None:
+        # No-op: without a block manager there are no block tables to free,
+        # but abort_seq_group still calls this.
+        # self.block_manager.free(seq)
+        pass
+
+    def free_finished_seq_groups(self) -> None:
+        for seq_group in self.running:
+            if seq_group.is_finished():
+                print("Finished sequence group:")
+                print(seq_group)
+        self.running = [
+            seq_group for seq_group in self.running
+            if not seq_group.is_finished()
+        ]
+
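+    # NOTE: the helpers below still reference self.block_manager, which this
+    # class never creates. They are carried over from Scheduler and are
+    # unreachable while the block manager stays disabled.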
+    def _allocate(self, seq_group: SequenceGroup) -> None:
+        self.block_manager.allocate(seq_group)
+        for seq in seq_group.get_seqs():
+            seq.status = SequenceStatus.RUNNING
+
+    def _append_slot(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> None:
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+            ret = self.block_manager.append_slot(seq)
+            if ret is not None:
+                src_block, dst_block = ret
+                if src_block in blocks_to_copy:
+                    blocks_to_copy[src_block].append(dst_block)
+                else:
+                    blocks_to_copy[src_block] = [dst_block]
+
+    def _preempt(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+        preemption_mode: Optional[PreemptionMode] = None,
+    ) -> None:
+        # If preemption mode is not specified, we determine the mode as follows:
+        # We use recomputation by default since it incurs lower overhead than
+        # swapping. However, when the sequence group has multiple sequences
+        # (e.g., beam search), recomputation is not currently supported. In
+        # such a case, we use swapping instead.
+        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
+        # As swapped sequences are prioritized over waiting sequences,
+        # sequence groups with multiple sequences are implicitly prioritized
+        # over sequence groups with a single sequence.
+        # TODO(woosuk): Support recomputation for sequence groups with multiple
+        # sequences. This may require a more sophisticated CUDA kernel.
+        if preemption_mode is None:
+            if seq_group.get_max_num_running_seqs() == 1:
+                preemption_mode = PreemptionMode.RECOMPUTE
+            else:
+                preemption_mode = PreemptionMode.SWAP
+        if preemption_mode == PreemptionMode.RECOMPUTE:
+            self._preempt_by_recompute(seq_group)
+        elif preemption_mode == PreemptionMode.SWAP:
+            self._preempt_by_swap(seq_group, blocks_to_swap_out)
+        else:
+            assert False, "Invalid preemption mode."
+
+    def _preempt_by_recompute(
+        self,
+        seq_group: SequenceGroup,
+    ) -> None:
+        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+        assert len(seqs) == 1
+        for seq in seqs:
+            seq.status = SequenceStatus.WAITING
+            self.block_manager.free(seq)
+        # NOTE: For FCFS, we insert the preempted sequence group to the front
+        # of the waiting queue.
+        self.waiting.insert(0, seq_group)
+
+    def _preempt_by_swap(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+    ) -> None:
+        self._swap_out(seq_group, blocks_to_swap_out)
+        self.swapped.append(seq_group)
+
+    def _swap_in(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_in: Dict[int, int],
+    ) -> None:
+        mapping = self.block_manager.swap_in(seq_group)
+        blocks_to_swap_in.update(mapping)
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
+            seq.status = SequenceStatus.RUNNING
+
+    def _swap_out(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+    ) -> None:
+        if not self.block_manager.can_swap_out(seq_group):
+            # FIXME(woosuk): Abort the sequence group instead of aborting the
+            # entire engine.
+            raise RuntimeError(
+                "Aborted due to the lack of CPU swap space. Please increase "
+                "the swap space to avoid this error.")
+        mapping = self.block_manager.swap_out(seq_group)
+        blocks_to_swap_out.update(mapping)
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+            seq.status = SequenceStatus.SWAPPED