
 from array import array
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Optional

 import torch

@@ -25,10 +25,10 @@ class SequenceGroupToSample:
     # |-- query_len ---|

     # Sequence ids for the sequence group in a previous step.
-    seq_ids: List[int]
+    seq_ids: list[int]
     sampling_params: SamplingParams
     # seq_id -> sequence data.
-    seq_data: Dict[int, SequenceData]
+    seq_data: dict[int, SequenceData]
     # The length of the sequence (all tokens seen in the past + new token to
     # compute attention) of the sequence group. None if it is in a decode
     # stage.
@@ -44,9 +44,9 @@ class SequenceGroupToSample:
     is_prompt: bool
     # Query token indices from logits. to compute prompt logprob. Empty if
     # prompt logprob is not required.
-    prompt_logprob_indices: List[int]
+    prompt_logprob_indices: list[int]
     # Sample token indices from logits. Empty if sampling is not required.
-    sample_indices: List[int]
+    sample_indices: list[int]

     @property
     def do_sample(self):
@@ -78,7 +78,7 @@ class SamplingMetadataCache:
     """Used to cache SamplingMetadata objects between scheduler iterations"""

     def __init__(self):
-        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
+        self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}

     def get_cached_seq_group_to_sample(self, num_seqs):
         if num_seqs not in self._seq_group_to_sample_cache:
@@ -130,9 +130,9 @@ def sample(logits):

     def __init__(
         self,
-        seq_groups: List[SequenceGroupToSample],
+        seq_groups: list[SequenceGroupToSample],
         selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        categorized_sample_indices: dict[SamplingType, torch.Tensor],
         num_prompts: int,
         skip_sampler_cpu_output: bool = False,
         reuse_sampling_tensors: bool = False,
@@ -146,12 +146,12 @@ def __init__(

     @staticmethod
     def prepare(
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        seq_lens: List[int],
-        query_lens: List[int],
+        seq_group_metadata_list: list[SequenceGroupMetadata],
+        seq_lens: list[int],
+        query_lens: list[int],
         device: str,
         pin_memory: bool,
-        generators: Optional[Dict[str, torch.Generator]] = None,
+        generators: Optional[dict[str, torch.Generator]] = None,
         cache: Optional[SamplingMetadataCache] = None,
     ) -> "SamplingMetadata":
         (
@@ -195,16 +195,16 @@ def __repr__(self) -> str:


 def _prepare_seq_groups(
-    seq_group_metadata_list: List[SequenceGroupMetadata],
-    seq_lens: List[int],
-    query_lens: List[int],
+    seq_group_metadata_list: list[SequenceGroupMetadata],
+    seq_lens: list[int],
+    query_lens: list[int],
     device: str,
-    generators: Optional[Dict[str, torch.Generator]] = None,
+    generators: Optional[dict[str, torch.Generator]] = None,
     cache: Optional[SamplingMetadataCache] = None,
-) -> Tuple[
-    List[SequenceGroupToSample],
-    List[int],
-    Dict[SamplingType, List[int]],
+) -> tuple[
+    list[SequenceGroupToSample],
+    list[int],
+    dict[SamplingType, list[int]],
     int,
 ]:
     """Prepare sequence groups and indices for sampling.
@@ -227,17 +227,17 @@ def _prepare_seq_groups(
         num_prompts: Total number of prompts from `seq_group_metadata_list`.
     """
     # Batched sequence groups for the current model forward step.
-    seq_groups: List[SequenceGroupToSample] = []
+    seq_groups: list[SequenceGroupToSample] = []
     # A list of token indices to sample/compute logprob. It is used to
     # prune the outcome logits from the model for the performance.
-    selected_token_indices: List[int] = []
+    selected_token_indices: list[int] = []
     # Used for selected_token_indices.
     model_output_idx = 0

     # Sampling type -> (
     # indices to sample/prompt logprob within pruned output logits,
     # indices to sample within pruned logits)
-    categorized_sample_indices: Dict[SamplingType, List[int]] = {
+    categorized_sample_indices: dict[SamplingType, list[int]] = {
         t: []
         for t in SamplingType
     }
@@ -265,9 +265,9 @@ def _prepare_seq_groups(
         # If the current seq group is in decode stage, it is None.
         seq_len: Optional[int] = None
         query_len: Optional[int] = None
-        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
+        prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices
                                              if cache is not None else [])
-        sample_indices: List[int] = (sample_obj.sample_indices
+        sample_indices: list[int] = (sample_obj.sample_indices
                                      if cache is not None else [])
         do_sample = seq_group_metadata.do_sample

@@ -389,16 +389,16 @@ def from_sampling_metadata(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> Tuple["SamplingTensors", bool, bool, bool]:
-        prompt_tokens: List[array] = []
-        output_tokens: List[array] = []
-        top_ks: List[int] = []
-        temperatures: List[float] = []
-        top_ps: List[float] = []
-        min_ps: List[float] = []
-        presence_penalties: List[float] = []
-        frequency_penalties: List[float] = []
-        repetition_penalties: List[float] = []
+    ) -> tuple["SamplingTensors", bool, bool, bool]:
+        prompt_tokens: list[array] = []
+        output_tokens: list[array] = []
+        top_ks: list[int] = []
+        temperatures: list[float] = []
+        top_ps: list[float] = []
+        min_ps: list[float] = []
+        presence_penalties: list[float] = []
+        frequency_penalties: list[float] = []
+        repetition_penalties: list[float] = []
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
@@ -496,15 +496,15 @@ def from_sampling_metadata(
     @classmethod
     def from_lists(
         cls,
-        temperatures: List[float],
-        top_ps: List[float],
-        top_ks: List[int],
-        min_ps: List[float],
-        presence_penalties: List[float],
-        frequency_penalties: List[float],
-        repetition_penalties: List[float],
-        prompt_tokens: List[array],
-        output_tokens: List[array],
+        temperatures: list[float],
+        top_ps: list[float],
+        top_ks: list[int],
+        min_ps: list[float],
+        presence_penalties: list[float],
+        frequency_penalties: list[float],
+        repetition_penalties: list[float],
+        prompt_tokens: list[array],
+        output_tokens: list[array],
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
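The change throughout this diff is mechanical: `typing.List`, `typing.Dict`, and `typing.Tuple` are replaced with the built-in generic types allowed as annotations since Python 3.9 (PEP 585), so only `Optional` survives in the trimmed `typing` import. As a minimal sketch of the same pattern outside vLLM (the names `ToySequenceGroup` and `split_indices` are illustrative stand-ins, not part of the real module):

```python
from array import array
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToySequenceGroup:
    """Illustrative stand-in for SequenceGroupToSample using PEP 585 generics."""

    seq_ids: list[int]
    # seq_id -> accumulated token ids, mirroring the dict[int, ...] annotations
    # introduced by the diff.
    seq_data: dict[int, array]
    seq_len: Optional[int] = None
    sample_indices: list[int] = field(default_factory=list)


def split_indices(
    groups: list[ToySequenceGroup],
) -> tuple[list[int], dict[int, list[int]]]:
    """Flatten sample indices and also key them by group position."""
    flat: list[int] = []
    per_group: dict[int, list[int]] = {}
    for i, group in enumerate(groups):
        per_group[i] = list(group.sample_indices)
        flat.extend(group.sample_indices)
    return flat, per_group


if __name__ == "__main__":
    g = ToySequenceGroup(seq_ids=[0],
                         seq_data={0: array("l", [1, 2, 3])},
                         sample_indices=[0, 1])
    print(split_indices([g]))  # ([0, 1], {0: [0, 1]})
```

On Python 3.10+ the remaining `Optional[int]` annotations could likewise be written `int | None` (PEP 604), but this diff keeps `Optional` and only drops the deprecated `typing` aliases.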