From 2394c41b9d03b80fe43534aeca2b66408ea78e02 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Mon, 1 Jul 2024 18:57:41 +0300
Subject: [PATCH] add trim_attn_metadata comment

---
 vllm/worker/habana_model_runner.py | 36 +++++++++++++++++++++++-------
 1 file changed, 28 insertions(+), 8 deletions(-)

diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index e53350ecfc1fc..49f66ae1e0863 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -803,14 +803,34 @@ def _seq_len(self, attn_metadata):
         return attn_metadata.block_tables.size(1) * self.block_size
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
-        prefill_metadata = subtuple(metadata,
-                                    'TrimmedAttentionMetadata',
-                                    ['block_tables',
-                                     'seq_lens_tensor',
-                                     'attn_bias',
-                                     'slot_mapping',
-                                     'is_prompt'])
-        return prefill_metadata
+        # NOTE(kzawora): To anyone working on this in the future:
+        # Trimming metadata is required when using HPUGraphs.
+        # Attention metadata is going to be hashed by PT bridge, and
+        # appropriate HPUGraphs will be matched based on all inputs' hash.
+
+        # Before you put more keys in here, make sure you know their
+        # value type and make sure you know how it's going to be hashed.
+        # You can find that information in input_hash function
+        # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+        # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+
+        # If you use primitive types here - they will get hashed based
+        # on their value. You *will* get lots of excessive graph captures
+        # (and an OOM eventually) if you decide to put something like
+        # seq_len int here.
+        # If you absolutely need a scalar, put it in a tensor. Tensors
+        # get hashed using their metadata, not their values:
+        # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+        # input_hash(123) != input_hash(321)
+        # input_hash("abc") != input_hash("cba")
+        attention_metadata = subtuple(metadata,
+                                      'TrimmedAttentionMetadata',
+                                      ['block_tables',
+                                       'seq_lens_tensor',
+                                       'attn_bias',
+                                       'slot_mapping',
+                                       'is_prompt'])
+        return attention_metadata
 
     @torch.inference_mode()
     def execute_model(
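
A quick way to sanity-check the hashing behaviour described in the new comment is sketched below. It is only a sketch: it assumes a Gaudi host with habana_frameworks installed, and that torch.hpu.graphs.input_hash is reachable exactly as the comment states; the precise import path may differ between SynapseAI releases.

    import torch
    import habana_frameworks.torch  # loads the HPU/PT bridge; assumed to expose torch.hpu.graphs

    input_hash = torch.hpu.graphs.input_hash  # helper referenced in the comment above

    # Tensors are hashed by their metadata (shape, dtype, device), not by value,
    # so two different scalars wrapped in tensors map to the same HPUGraph.
    assert input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))

    # Python scalars and strings are hashed by value, so keeping e.g. a raw
    # seq_len int in the trimmed metadata would capture a new graph per value.
    assert input_hash(123) != input_hash(321)
    assert input_hash("abc") != input_hash("cba")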