@@ -92,6 +92,7 @@ def __init__(
         device_config: DeviceConfig,
         cache_config: CacheConfig,
         load_config: LoadConfig,
+        trace_mode: bool = True,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -105,6 +106,9 @@ def __init__(
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
 
+        self.trace_mode = trace_mode  # whether to use ttnn tracing for model execution
+        self.execute_trace_kwargs = None  # kwargs for trace execution (populated during first decode execution)
+
     def load_model(self) -> None:
         # Note: using custom TT loader instead of selecting from default vllm loaders
         loader = TTModelLoader(self.load_config)
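For context: ttnn tracing records the device-side op sequence of a model once and replays it on later calls, skipping host-side dispatch; replay requires every traced tensor to keep a constant shape, which is why the trace handles are cached on the runner and only captured once real decode inputs exist. A minimal sketch of the general pattern, assuming the standard ttnn trace APIs (begin_trace_capture / end_trace_capture / execute_trace / release_trace and copy_host_to_device_tensor); run_decode_ops and the tensors here are hypothetical:

import ttnn

def capture_then_replay(device, persistent_input, run_decode_ops, host_input):
    # Capture: run the ops once under capture; their shapes are baked into the trace
    trace_id = ttnn.begin_trace_capture(device, cq_id=0)
    output = run_decode_ops(persistent_input)
    ttnn.end_trace_capture(device, trace_id, cq_id=0)

    # Replay: refresh the persistent input tensor in place, then execute the trace
    ttnn.copy_host_to_device_tensor(host_input, persistent_input)
    ttnn.execute_trace(device, trace_id, cq_id=0, blocking=True)

    ttnn.release_trace(device, trace_id)  # free the trace buffer when done
    return output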
@@ -234,6 +238,13 @@ def prepare_model_input(
                 block_tables,
                 torch.zeros(batch_pad_len, block_tables.shape[1], dtype=torch.int32, device="cpu")
             ])
+
+            # Pad block_tables to max num blocks so ttnn tracing can work (requires constant shape)
+            if self.trace_mode:
+                block_tables = torch.cat([
+                    block_tables,
+                    torch.zeros(block_tables.shape[0], self.cache_config.num_gpu_blocks - block_tables.shape[1], dtype=torch.int32, device="cpu")
+                ], dim=1)
 
         return TTModelInput(input_tokens, input_positions, prompt_lens, seq_groups, block_tables, unpadded_batch_size, tt_sampling_params)
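Concretely, with made-up numbers: if the padded batch is 32, each sequence currently has 5 allocated blocks, and num_gpu_blocks is 2048, this widens block_tables from [32, 5] (a width that varies step to step as sequences grow) to a constant [32, 2048]:

import torch

num_gpu_blocks = 2048                                  # illustrative value
block_tables = torch.zeros(32, 5, dtype=torch.int32)   # width varies per decode step
pad = torch.zeros(block_tables.shape[0], num_gpu_blocks - block_tables.shape[1],
                  dtype=torch.int32)
block_tables = torch.cat([block_tables, pad], dim=1)
print(block_tables.shape)  # torch.Size([32, 2048]) -- constant across steps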
@@ -257,7 +268,35 @@ def execute_model(
             "prompt_lens": model_input.prompt_lens,
         }
 
-        logits = self.model.forward(**execute_model_kwargs)  # [batch_size, seq_len, vocab_size]
+        is_decode = model_input.prompt_lens is None
+
+        if self.trace_mode and is_decode:  # Trace mode for decode
+            # Remove prompt_lens from execute_model_kwargs since it's not used for decode
+            execute_model_kwargs.pop("prompt_lens")
+
+            # Capture trace for the first decode execution
+            if self.execute_trace_kwargs is None:
+                logger.info("Capturing trace for first decode execution")
+                trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table = self.model.capture_trace(
+                    **execute_model_kwargs
+                )
+                self.execute_trace_kwargs = {
+                    "trace_id": trace_id,
+                    "tt_inp": tt_inp,
+                    "rot_mat": rot_mat,
+                    "cache_idxs_tt": cache_idxs_tt,
+                    "tt_logits": tt_logits,
+                    "tt_page_table": tt_page_table,
+                }
+
+            # Remove kv_cache from execute_model_kwargs since it doesn't need to be copied to device for trace execution
+            execute_model_kwargs.pop("kv_cache")
+
+            logits = self.model.decode_forward_trace(
+                **execute_model_kwargs, **self.execute_trace_kwargs
+            )
+        else:
+            logits = self.model.forward(**execute_model_kwargs)  # [batch_size, seq_len, vocab_size]
 
         # Note: for other devices, vLLM applies vllm.model_executor.layers.logits_processor::LogitsProcessor::_apply_logits_processors on logits, we don't use this
         # Note: for other devices, vLLM applies vllm.model_executor.layers.sampler::Sampler for sampling tokens, we don't use this
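With this change, only the first decode step pays for compilation and trace capture; every later decode step just copies fresh inputs into the persistent device tensors and replays the recorded ops. A quick, illustrative way to observe that amortization (helper is hypothetical; assumes an already-built runner and inputs):

import time

def time_decode_steps(runner, decode_input, kv_caches, n_steps=5):
    # Step 0 includes capture_trace (compile + record); steps 1+ replay the trace
    for step in range(n_steps):
        t0 = time.perf_counter()
        runner.execute_model(decode_input, kv_caches)
        print(f"decode step {step}: {time.perf_counter() - t0:.4f}s")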
@@ -292,4 +331,13 @@ def _validate_sampling_params(self, sampling_params):
         assert sampling_params.best_of == 1, "Currently only supporting best_of=1"
         assert not sampling_params.use_beam_search, "Currently not supporting beam search"
         assert sampling_params.logprobs is None, "Currently not supporting logprobs"
-        assert sampling_params.prompt_logprobs is None, "Currently not supporting prompt_logprobs"
+        assert sampling_params.prompt_logprobs is None, "Currently not supporting prompt_logprobs"
+
+    ## Destructor (used to delete ttnn trace if using trace mode)
+
+    def __del__(self):
+        if self.trace_mode and self.execute_trace_kwargs is not None:
+            self.model.delete_trace(self.execute_trace_kwargs["trace_id"])
+
+        if hasattr(super(TTModelRunner, self), '__del__'):
+            super().__del__()
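One design note: relying on __del__ means the trace is released whenever the garbage collector finalizes the runner, which is nondeterministic at interpreter shutdown. A hedged alternative (not part of this PR) would be an explicit cleanup method the engine can call at a known point; `delete_trace` below is the model's own method, the `cleanup` name is hypothetical:

# Hypothetical explicit-cleanup variant of the destructor logic above
def cleanup(self):
    if self.trace_mode and self.execute_trace_kwargs is not None:
        self.model.delete_trace(self.execute_trace_kwargs["trace_id"])
        self.execute_trace_kwargs = None  # makes cleanup idempotent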