
Commit 24e6ad3

[V1] Remove num_input_tokens from attn_metadata (#17193)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent 2ef5d10 commit 24e6ad3


6 files changed: 14 additions, 21 deletions


vllm/forward_context.py

Lines changed: 7 additions & 9 deletions
@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        if attn_metadata is not None:
-            if hasattr(attn_metadata, "num_prefill_tokens"):
-                # for v0 attention backends
-                batchsize = attn_metadata.num_prefill_tokens + \
-                    attn_metadata.num_decode_tokens
-            else:
-                # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+        if attn_metadata is not None and hasattr(attn_metadata,
+                                                 "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
         else:
+            # for v1 attention backends or no attn_metadata
             batchsize = num_tokens
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = batchsize
@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
                 attn_metadata.num_decode_tokens
         else:
             # for v1 attention backends
-            batchsize = attn_metadata.num_input_tokens
+            batchsize = num_tokens
         # we use synchronous scheduling right now,
         # adding a sync point here should not affect
         # scheduling of the next batch
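
The hunk above makes the data-parallel padding logic take its token count from the explicit num_tokens argument whenever the metadata belongs to a v1 backend or is absent. A minimal, self-contained sketch of that selection logic follows; pick_batchsize and V0Metadata are hypothetical stand-in names for illustration, not vLLM API.

# Sketch only: mirrors the branch added in set_forward_context, using
# hypothetical stand-in names (pick_batchsize, V0Metadata are not vLLM API).
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class V0Metadata:
    # v0 attention metadata exposes per-phase token counts.
    num_prefill_tokens: int
    num_decode_tokens: int


def pick_batchsize(attn_metadata: Optional[Any], num_tokens: int) -> int:
    if attn_metadata is not None and hasattr(attn_metadata,
                                             "num_prefill_tokens"):
        # v0 attention backends: token counts live on the metadata object.
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # v1 attention backends (or no metadata at all): the caller now supplies
    # the count directly instead of stashing it on attn_metadata.
    return num_tokens


assert pick_batchsize(V0Metadata(3, 5), num_tokens=99) == 8
assert pick_batchsize(None, num_tokens=42) == 42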

vllm/v1/attention/backends/flash_attn.py

Lines changed: 0 additions & 3 deletions
@@ -94,9 +94,6 @@ class FlashAttentionMetadata:
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # for local attention
     @dataclass
     class LocalAttentionMetadata:

vllm/v1/attention/backends/flashinfer.py

Lines changed: 0 additions & 3 deletions
@@ -183,9 +183,6 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.

vllm/v1/attention/backends/mla/common.py

Lines changed: 0 additions & 3 deletions
@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
     num_decode_tokens: int
     num_prefills: int
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # The dimension of the attention heads
     head_dim: Optional[int] = None
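
All three v1 metadata classes above lose the same logging-only field. A hedged sketch of the pattern, with a hypothetical stand-in dataclass (ExampleV1Metadata is illustrative, not a vLLM class):

# Illustrative only: ExampleV1Metadata stands in for FlashAttentionMetadata,
# FlashInferMetadata, and MLACommonMetadata to show the shape of the change.
from dataclasses import dataclass


@dataclass
class ExampleV1Metadata:
    num_actual_tokens: int  # example field; the real classes keep their own fields
    # Removed by this commit:
    # num_input_tokens: int = 0  # Number of tokens including padding.
    # The padded count is now passed as set_forward_context(num_tokens=...)
    # instead of being stored on the attention metadata.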

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 2 deletions
@@ -1036,7 +1036,6 @@ def execute_model(
             num_input_tokens = round_up(num_scheduled_tokens, tp_size)
         else:
             num_input_tokens = num_scheduled_tokens
-        attn_metadata.num_input_tokens = num_input_tokens
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1088,7 +1087,9 @@ def execute_model(
 
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
             output = self.model(
                 input_ids=input_ids,
                 positions=positions,
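
The GPU runner keeps computing the padded count locally and now hands it to set_forward_context rather than writing it onto attn_metadata. A rough sketch of that call shape, assuming a local round_up helper as a stand-in for the utility referenced in the diff:

# Sketch of the new call shape; round_up here is a local stand-in, and the
# context-manager call is shown commented out because it needs a real model
# runner, config, and metadata object to execute.
def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple.
    return ((x + multiple - 1) // multiple) * multiple


num_scheduled_tokens = 37
tp_size = 8
num_input_tokens = round_up(num_scheduled_tokens, tp_size)  # -> 40

# with set_forward_context(attn_metadata,
#                          self.vllm_config,
#                          num_tokens=num_input_tokens):
#     output = self.model(...)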

vllm/v1/worker/tpu_model_runner.py

Lines changed: 4 additions & 1 deletion
@@ -769,7 +769,10 @@ def execute_model(
         xm.mark_step()
         num_reqs = self.input_batch.num_reqs
         # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=scheduler_output.total_num_scheduled_tokens):
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,
