Commit

add trim_attn_metadata comment
kzawora-intel committed Jul 1, 2024
1 parent 1ec95c4 commit 2394c41
Showing 1 changed file with 28 additions and 8 deletions.
vllm/worker/habana_model_runner.py: 28 additions & 8 deletions

@@ -803,14 +803,34 @@ def _seq_len(self, attn_metadata):
         return attn_metadata.block_tables.size(1) * self.block_size
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
-        prefill_metadata = subtuple(metadata,
-                                    'TrimmedAttentionMetadata',
-                                    ['block_tables',
-                                     'seq_lens_tensor',
-                                     'attn_bias',
-                                     'slot_mapping',
-                                     'is_prompt'])
-        return prefill_metadata
+        # NOTE(kzawora): To anyone working on this in the future:
+        # Trimming metadata is required when using HPUGraphs.
+        # Attention metadata is going to be hashed by the PT bridge, and
+        # appropriate HPUGraphs will be matched based on all inputs' hash.
+
+        # Before you put more keys in here, make sure you know their
+        # value type and make sure you know how it's going to be hashed.
+        # You can find that information in the input_hash function
+        # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+        # it manually with torch.hpu.graphs.input_hash(attention_metadata).
+
+        # If you use primitive types here, they will get hashed based
+        # on their value. You *will* get lots of excessive graph captures
+        # (and eventually an OOM) if you decide to put something like
+        # a seq_len int here.
+        # If you absolutely need a scalar, put it in a tensor. Tensors
+        # get hashed using their metadata, not their values:
+        # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+        # input_hash(123) != input_hash(321)
+        # input_hash("abc") != input_hash("cba")
+        attention_metadata = subtuple(metadata,
+                                      'TrimmedAttentionMetadata',
+                                      ['block_tables',
+                                       'seq_lens_tensor',
+                                       'attn_bias',
+                                       'slot_mapping',
+                                       'is_prompt'])
+        return attention_metadata

     @torch.inference_mode()
     def execute_model(
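To make the hashing rule described in the new comment concrete, here is a toy illustration. toy_input_hash is a made-up stand-in, not the real habana_frameworks input_hash; it only mimics the behavior the comment describes (tensors hashed by metadata such as shape, dtype, and device; primitives hashed by value).

import torch

def toy_input_hash(x):
    # Toy stand-in mimicking the documented behavior: tensors hash by
    # their metadata (shape, dtype, device), not by their values;
    # primitives hash by value.
    if isinstance(x, torch.Tensor):
        return hash((tuple(x.shape), x.dtype, str(x.device)))
    return hash(x)

# Equal-shaped scalar tensors hash identically regardless of value:
assert toy_input_hash(torch.tensor(123)) == toy_input_hash(torch.tensor(321))
# Raw ints hash by value, so every new scalar value would trigger a
# fresh (and eventually OOM-inducing) graph capture:
assert toy_input_hash(123) != toy_input_hash(321)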
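For context, the subtuple helper called in the diff is defined elsewhere in the file and is not part of this commit. A minimal sketch of how such a helper could work, assuming it extracts the listed attributes into a cached namedtuple type (the caching dict and the exact signature are assumptions, not the repository's implementation):

import collections

_TYPE_CACHE = {}

def subtuple(obj, typename, fields):
    # Create (and cache) a namedtuple type with only the requested
    # fields, then populate it from the source object's attributes.
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename, fields)
    return _TYPE_CACHE[typename](**{f: getattr(obj, f) for f in fields})

Because the trimmed object carries only the listed fields, the set of inputs the PT bridge hashes stays small and stable.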
