from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tt_loader import TTModelLoader
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                           Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (ModelRunnerBase,
                                           ModelRunnerInputBase)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)


@dataclass(frozen=True)
class TTModelInput(ModelRunnerInputBase):
    """
    Used by the TTModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    prompt_lens: Optional[torch.Tensor] = None
    seq_groups: Optional[List[List[int]]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
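        """Pack the model input into a dict that vLLM can broadcast from the
        driver worker to the other workers (standard ModelRunnerInputBase
        contract)."""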
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "prompt_lens": self.prompt_lens,
            "seq_groups": self.seq_groups,
        }

        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["TTModelInput"],
        tensor_dict: Dict[str, Any],
    ) -> "TTModelInput":
        return cls(**tensor_dict)


class TTModelRunner(ModelRunnerBase[TTModelInput]):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # Currently, the TT worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config

        self.device = self.device_config.device

        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

    def load_model(self) -> None:
        # Note: use the custom TT loader instead of selecting from the
        # default vLLM loaders.
        loader = TTModelLoader(self.load_config)
        self.model = loader.load_model(model_config=self.model_config,
                                       device_config=self.device_config,
                                       parallel_config=self.parallel_config,
                                       scheduler_config=self.scheduler_config,
                                       cache_config=self.cache_config)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> TTModelInput:
        return TTModelInput.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> TTModelInput:
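        """Flatten the scheduler output into a TTModelInput.

        The batch must be homogeneous: for prefill, all prompt tokens and
        positions are flattened and prompt_lens is set; for decode, there is
        one last token and one position per sequence and prompt_lens is None.
        """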
        # NOTE: We assume the batch is homogeneous: either all sequence
        # groups are prefills or all are decodes.
        # Prefill if True, otherwise decode.
        is_prompt = seq_group_metadata_list[0].is_prompt
        assert all(
            x.is_prompt == is_prompt for x in seq_group_metadata_list
        ), "Currently only supporting all prefills or all decodes in seq group"

        batch_size = len(seq_group_metadata_list)
        assert batch_size > 0

        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            # Only support one sequence per request group.
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]

            if is_prompt:
                # tokens
                prompt_tokens = seq_data.get_token_ids()
                input_tokens.extend(prompt_tokens)

                # positions
                prompt_len = len(prompt_tokens)
                prompt_lens.append(prompt_len)
                input_positions.extend(list(range(prompt_len)))
            else:
                # tokens
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                # positions
                position = seq_data.get_len() - 1
                input_positions.append(position)

            # TODO: Get block table using
            # seq_group_metadata.block_tables[seq_id]

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        if is_prompt:
            prompt_lens = torch.tensor(prompt_lens,
                                       dtype=torch.int32,
                                       device="cpu")
        else:
            prompt_lens = None

        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]

        return TTModelInput(input_tokens, input_positions, prompt_lens,
                            seq_groups)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: TTModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
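        """Run one prefill or decode step on the TT device, greedily sample
        one next token per sequence, and wrap the result in a single
        SamplerOutput.
        """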
        if num_steps > 1:
            raise ValueError(
                "TT worker does not support multi-step execution.")

        # prompt_lens is only set for prefill batches.
        is_prompt = model_input.prompt_lens is not None

        if is_prompt:
            input_position = 0
            # Currently only support the same prompt length across the batch.
            assert torch.all(
                model_input.prompt_lens == model_input.prompt_lens[0]
            ), "Currently only supporting same prompt lengths for prefill"
            batch_size = model_input.prompt_lens.shape[0]
        else:
            # Currently only support the same decode position across the
            # batch.
            input_position = model_input.input_positions[0].item()
            assert torch.all(
                model_input.input_positions == input_position
            ), "Currently only supporting same input positions for decode"
            batch_size = model_input.input_tokens.shape[0]

        input_tokens = model_input.input_tokens.view(batch_size, -1)

        execute_model_kwargs = {
            "tokens": input_tokens,
            "start_pos": input_position,
            # TODO: Add block table and maybe kv cache
        }

        # [batch_size, seq_len, vocab_size]
        logits = self.model.forward(**execute_model_kwargs)

        # Note: for other devices, vLLM applies LogitsProcessor
        # (vllm.model_executor.layers.logits_processor) to the logits and
        # samples tokens with Sampler (vllm.model_executor.layers.sampler);
        # neither is used here.
        next_logits = logits[:, -1, :]  # last-token logits per sequence
        # Convert to Python ints so the engine gets plain token ids.
        next_token_ids = self._sample_tokens(next_logits).tolist()

        # Minimal code to construct the sampler outputs, based on
        # tpu_model_runner.py. The TT backend does not support advanced
        # sampling parameters such as logprobs.
        zero_logprob = Logprob(0.0)
        sampler_outputs = []
        for batch_idx, seq_ids in enumerate(model_input.seq_groups):
            # Only support one sequence per request group.
            assert len(seq_ids) == 1
            next_token_id = next_token_ids[batch_idx]
            seq_outputs = [
                SequenceOutput(seq_ids[0], next_token_id,
                               {next_token_id: zero_logprob})
            ]
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))
        return [SamplerOutput(sampler_outputs)]

    def _sample_tokens(self, logits):
        # TODO: Add other sampling methods; currently only greedy sampling.
        return torch.argmax(logits, dim=-1)
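
# Illustrative usage sketch only (not part of the original module): a TT
# worker that owns this runner would typically drive it roughly as follows.
# The names `runner`, `seq_group_metadata_list`, and the empty kv_caches list
# are placeholders/assumptions, not definitions from this file.
#
#   runner = TTModelRunner(model_config, parallel_config, scheduler_config,
#                          device_config, cache_config, load_config)
#   runner.load_model()
#   model_input = runner.prepare_model_input(seq_group_metadata_list)
#   sampler_outputs = runner.execute_model(model_input, kv_caches=[])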