Commit 18f4c6e

vllm-project#14: Add trace_mode option to TTWorker and TTModelRunner, update perf measurement to decode multiple tokens
Signed-off-by: Salar Hosseini <skhorasgani@tenstorrent.com>
1 parent 678885a commit 18f4c6e

File tree: 4 files changed (+74, -10 lines)

examples/offline_inference_tt.py

Lines changed: 13 additions & 5 deletions

@@ -64,7 +64,9 @@ def run_inference_perf(
     print("Measuring performance with dummy prompts of length", input_prompt_len)
     prompt_token_ids = [[0]*input_prompt_len]*max_seqs_in_batch # dummy prompts
     sampling_params = sampling_params[:max_seqs_in_batch] if isinstance(sampling_params, list) else sampling_params
-    sampling_params.max_tokens = 2 # 1 prefill output token + 1 decode output token
+
+    # Set an arbitrary max_tokens to simulate generating multiple tokens consecutively
+    sampling_params.max_tokens = 33 # 1 prefill output token + 32 decode output tokens

     # Compile run
     print("Starting compile run")

@@ -74,8 +76,8 @@ def run_inference_perf(

     # Inference runs
     print("Starting inference runs")
-    N_warmup = 5
-    N_inference = 15
+    N_warmup = 1
+    N_inference = 5
     for i in tqdm(range(N_inference), desc="Inference runs"):
         if i == N_warmup: # Reset stats after warmup
             llm.llm_engine.stat_loggers['global'].reset()

@@ -105,7 +107,13 @@ def generate_tokens(llm : LLM, prompts, sampling_params, prompt_token_ids=None,
 parser = argparse.ArgumentParser()
 parser.add_argument("--prompts_json", type=str, default="tt_metal/prompts.json", help="Path to JSON file containing prompts")
 parser.add_argument("--measure_perf", action="store_true", help="Measure performance")
-parser.add_argument("--perf_prompt_len", type=int, default=127, help="Length of dummy prompts for performance measurement")
+parser.add_argument("--perf_prompt_len", type=int, default=128, help="Length of dummy prompts for performance measurement")
+parser.add_argument("--greedy_sampling", action="store_true", help="Use greedy decoding instead of top-k/p")
 args = parser.parse_args()

-run_inference(args.prompts_json, measure_perf=args.measure_perf, perf_prompt_len=args.perf_prompt_len)
+run_inference(
+    args.prompts_json,
+    measure_perf=args.measure_perf,
+    perf_prompt_len=args.perf_prompt_len,
+    greedy_sampling=args.greedy_sampling
+)
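Note: the new --greedy_sampling flag is only plumbed into run_inference here; the hunk does not show how it becomes sampling parameters. A minimal sketch of one plausible mapping, assuming vLLM's standard SamplingParams class (temperature=0.0 gives argmax decoding); the helper name and the top-k/top-p defaults are illustrative, not part of this commit:

from vllm import SamplingParams

def make_sampling_params(greedy_sampling: bool, max_tokens: int = 33) -> SamplingParams:
    # Hypothetical helper: map the CLI flag onto vLLM sampling parameters
    if greedy_sampling:
        # temperature=0.0 makes vLLM always pick the highest-probability token
        return SamplingParams(temperature=0.0, max_tokens=max_tokens)
    # otherwise keep stochastic top-k/top-p sampling (values here are placeholders)
    return SamplingParams(temperature=1.0, top_k=10, top_p=0.9, max_tokens=max_tokens)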

tt_metal/README.md

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 ## vLLM and tt-metal Branches
 Git-checkout the following branches in each repo separately:
 - vLLM branch: [dev](https://github.com/tenstorrent/vllm/tree/dev) (last verified commit: [3f7beb2](https://github.com/tenstorrent/vllm/tree/3f7beb23cbaf3be2e104061905da5f91644e5a68))
-- tt-metal branch: [main](https://github.com/tenstorrent/tt-metal) (last verified commit: [f521af0](https://github.com/tenstorrent/tt-metal/tree/f521af0061bf53567942b7a27fd89aa300ec16ce))
+- tt-metal branch: [main](https://github.com/tenstorrent/tt-metal) (last verified commit: [f0b2483](https://github.com/tenstorrent/tt-metal/tree/f0b2483529a55d1101eb142ae1c70eec5260ecf7))

 ## Environment Creation

vllm/worker/tt_model_runner.py

Lines changed: 50 additions & 2 deletions

@@ -92,6 +92,7 @@ def __init__(
         device_config: DeviceConfig,
         cache_config: CacheConfig,
         load_config: LoadConfig,
+        trace_mode: bool = True,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config

@@ -105,6 +106,9 @@ def __init__(
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size

+        self.trace_mode = trace_mode # whether to use ttnn tracing for model execution
+        self.execute_trace_kwargs = None # kw args for trace execution (populated during first decode execution)
+
     def load_model(self) -> None:
         # Note: using custom TT loader instead of selecting from default vllm loaders
         loader = TTModelLoader(self.load_config)
@@ -234,6 +238,13 @@ def prepare_model_input(
                 block_tables,
                 torch.zeros(batch_pad_len, block_tables.shape[1], dtype=torch.int32, device="cpu")
             ])
+
+        # Pad block_tables to max num blocks so ttnn tracing can work (requires constant shape)
+        if self.trace_mode:
+            block_tables = torch.cat([
+                block_tables,
+                torch.zeros(block_tables.shape[0], self.cache_config.num_gpu_blocks - block_tables.shape[1], dtype=torch.int32, device="cpu")
+            ], dim=1)

         return TTModelInput(input_tokens, input_positions, prompt_lens, seq_groups, block_tables, unpadded_batch_size, tt_sampling_params)
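Since ttnn trace replay requires every input tensor to keep the exact shape it had at capture time, the block table is padded out to the device's total block count on every decode step, regardless of how many blocks each sequence actually uses. A small self-contained sketch of that padding, with made-up sizes (batch of 4, 8 allocated blocks, 64 total device blocks) standing in for the runner's real block tables and cache_config.num_gpu_blocks:

import torch

num_device_blocks = 64  # assumed total number of KV-cache blocks on the device
block_tables = torch.randint(0, num_device_blocks, (4, 8), dtype=torch.int32)  # 4 seqs x 8 blocks each

# Right-pad with zeros so the width is always num_device_blocks
pad_cols = num_device_blocks - block_tables.shape[1]
padded = torch.cat([
    block_tables,
    torch.zeros(block_tables.shape[0], pad_cols, dtype=torch.int32),
], dim=1)

assert padded.shape == (4, num_device_blocks)  # shape no longer depends on per-sequence block usage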
@@ -257,7 +268,35 @@ def execute_model(
             "prompt_lens": model_input.prompt_lens,
         }

-        logits = self.model.forward(**execute_model_kwargs) # [batch_size, seq_len, vocab_size]
+        is_decode = model_input.prompt_lens is None
+
+        if self.trace_mode and is_decode: # Trace mode for decode
+            # Remove prompt_lens from execute_model_kwargs since it's not used for decode
+            execute_model_kwargs.pop("prompt_lens")
+
+            # Capture trace for the first decode execution
+            if self.execute_trace_kwargs is None:
+                logger.info("Capturing trace for first decode execution")
+                trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table = self.model.capture_trace(
+                    **execute_model_kwargs
+                )
+                self.execute_trace_kwargs = {
+                    "trace_id": trace_id,
+                    "tt_inp": tt_inp,
+                    "rot_mat": rot_mat,
+                    "cache_idxs_tt": cache_idxs_tt,
+                    "tt_logits": tt_logits,
+                    "tt_page_table": tt_page_table,
+                }
+
+            # Remove kv_cache from execute_model_kwargs since it doesn't need to be copied to device for trace execution
+            execute_model_kwargs.pop("kv_cache")
+
+            logits = self.model.decode_forward_trace(
+                **execute_model_kwargs, **self.execute_trace_kwargs
+            )
+        else:
+            logits = self.model.forward(**execute_model_kwargs) # [batch_size, seq_len, vocab_size]

         # Note: for other devices, vLLM applies vllm.model_executor.layers.logits_processor::LogitsProcessor::_apply_logits_processors on logits, we don't use this
         # Note: for other devices, vLLM applies vllm.model_executor.layers.sampler::Sampler for sampling tokens, we don't use this
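The model-side methods used above (capture_trace, decode_forward_trace, and delete_trace further down) belong to the tt-metal model wrapper and are not part of this diff. As rough orientation only, a hedged sketch of the capture-once / replay-many pattern they presumably wrap, written against ttnn's generic tracing calls (begin_trace_capture, end_trace_capture, execute_trace, release_trace, copy_host_to_device_tensor); treat those call names as assumptions, and run_decode_graph and the input-update step as placeholders rather than this model's actual implementation:

import ttnn

def capture_decode_trace(device, device_inputs):
    # Run the decode graph once inside a capture region; the commands are recorded for later replay
    trace_id = ttnn.begin_trace_capture(device, cq_id=0)
    device_logits = run_decode_graph(device_inputs)  # placeholder for the model's decode forward pass
    ttnn.end_trace_capture(device, trace_id, cq_id=0)
    return trace_id, device_logits  # keep the output tensor alive so it can be read back after each replay

def replay_decode_trace(device, trace_id, device_inputs, new_host_inputs):
    # Overwrite the persistent device tensors in place; shapes must match the captured ones
    for device_tensor, host_tensor in zip(device_inputs, new_host_inputs):
        ttnn.copy_host_to_device_tensor(host_tensor, device_tensor)
    ttnn.execute_trace(device, trace_id, cq_id=0, blocking=False)

def delete_decode_trace(device, trace_id):
    ttnn.release_trace(device, trace_id)  # free the captured command buffers (cf. delete_trace in the destructor below)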
@@ -292,4 +331,13 @@ def _validate_sampling_params(self, sampling_params):
         assert sampling_params.best_of == 1, "Currently only supporting best_of=1"
         assert not sampling_params.use_beam_search, "Currently not supporting beam search"
         assert sampling_params.logprobs is None, "Currently not supporting logprobs"
-        assert sampling_params.prompt_logprobs is None, "Currently not supporting prompt_logprobs"
+        assert sampling_params.prompt_logprobs is None, "Currently not supporting prompt_logprobs"
+
+    ## Destructor (used to delete ttnn trace if using trace mode)
+
+    def __del__(self):
+        if self.trace_mode and self.execute_trace_kwargs is not None:
+            self.model.delete_trace(self.execute_trace_kwargs["trace_id"])
+
+        if hasattr(super(TTModelRunner, self), '__del__'):
+            super().__del__()

vllm/worker/tt_worker.py

Lines changed: 10 additions & 2 deletions

@@ -181,13 +181,16 @@ def __init__(
         self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             self.cache_config.cache_dtype]

+        self.trace_mode = True # whether to use ttnn tracing for model execution, TODO: make this configurable
+
         self.model_runner: TTModelRunner = TTModelRunner(
             model_config,
             parallel_config,
             scheduler_config,
             device_config,
             cache_config,
-            load_config
+            load_config,
+            trace_mode=self.trace_mode,
         )

         self.cache_engine: List[TTCacheEngine]
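trace_mode is hard-coded to True here, with a TODO to make it configurable. Purely as an illustration of that TODO, one possible wiring via an environment variable; the variable name VLLM_TT_TRACE_MODE is made up and not part of this commit:

import os

def _get_trace_mode(default: bool = True) -> bool:
    # Hypothetical: let an env var override the default instead of hard-coding True
    value = os.environ.get("VLLM_TT_TRACE_MODE")
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes")

# e.g. self.trace_mode = _get_trace_mode()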
@@ -371,7 +374,10 @@ def _get_dispatch_core_type(self):
         return dispatch_core_type

     def _open_t3k_mesh_device(self):
-        device_params = {}
+        if self.trace_mode:
+            device_params = {"trace_region_size": 14227456} # TODO: make this configurable
+        else:
+            device_params = {}
         mesh_device = ttnn.open_mesh_device(
             ttnn.MeshShape(2, 4),
             dispatch_core_type=self._get_dispatch_core_type(),
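Roughly speaking, trace_region_size reserves a region of device memory for the captured command stream, so it must be large enough to hold the traced decode graph; here it is hard-coded (with a TODO) to a value presumably sized for this model. The closing lines of the open_mesh_device call are outside the hunk; presumably the dict is forwarded as keyword arguments, along these lines (the **device_params forwarding is an assumption, not shown in the diff):

import ttnn

def open_t3k_mesh_device(trace_mode: bool, dispatch_core_type):
    device_params = {"trace_region_size": 14227456} if trace_mode else {}
    return ttnn.open_mesh_device(
        ttnn.MeshShape(2, 4),  # 2x4 T3000 mesh, as in the diff
        dispatch_core_type=dispatch_core_type,
        **device_params,  # extra device options such as trace_region_size
    )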
@@ -398,6 +404,8 @@ def _enable_async_mode(self):
     ## Destructor (used to close devices)

     def __del__(self):
+        del self.model_runner # Delete model runner first in case there are model artifacts (e.g. ttnn trace)
+
         if self.mesh_device:
             devices = self.mesh_device.get_devices()
