
Commit cdc1fa1

Remove unused kwargs from model definitions (#13555)
1 parent f61528d commit cdc1fa1


104 files changed: +436, -1654 lines

docs/source/contributing/model/basic.md

Lines changed: 0 additions & 2 deletions

@@ -74,8 +74,6 @@ def forward(
     self,
     input_ids: torch.Tensor,
     positions: torch.Tensor,
-    kv_caches: List[torch.Tensor],
-    attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
     ...
 ```
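
To illustrate what the updated documentation now describes, here is a minimal sketch of the post-change `forward` signature; the class name and body are illustrative placeholders, not code from this commit:

```python
import torch
from torch import nn


class MyModelForCausalLM(nn.Module):  # hypothetical example class
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        # kv_caches and attn_metadata are no longer threaded through the model;
        # attention layers read them from vLLM's forward context instead.
        ...
```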

docs/source/contributing/model/multimodal.md

Lines changed: 0 additions & 2 deletions

@@ -16,8 +16,6 @@ Further update the model as follows:
     self,
     input_ids: torch.Tensor,
     positions: torch.Tensor,
-    kv_caches: List[torch.Tensor],
-    attn_metadata: AttentionMetadata,
 +   pixel_values: torch.Tensor,
 ) -> SamplerOutput:
 ```

tests/kernels/test_encoder_decoder_attn.py

Lines changed: 3 additions & 11 deletions

@@ -644,11 +644,7 @@ def _run_encoder_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(
-            reshaped_query, packed_qkv.key, packed_qkv.value,
-            torch.tensor([],
-                         dtype=torch.float32,
-                         device=packed_qkv.query.device), attn_metadata)
+        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)


 def _run_decoder_self_attention_test(
@@ -682,7 +678,6 @@ def _run_decoder_self_attention_test(
     & attn_metadata
     '''
     attn = test_rsrcs.attn
-    kv_cache = test_rsrcs.kv_cache
     packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
     with set_forward_context(attn_metadata, vllm_config):
@@ -695,8 +690,7 @@ def _run_decoder_self_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
-                            kv_cache, attn_metadata)
+        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)


 def _run_encoder_decoder_cross_attention_test(
@@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test(
     assert decoder_test_params.packed_qkvo.packed_qkv is not None

     attn = test_rsrcs.attn
-    kv_cache = test_rsrcs.kv_cache
     if cross_test_params is None:
         key = None
         value = None
@@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query, key, value, kv_cache,
-                            attn_metadata)
+        return attn.forward(reshaped_query, key, value)


 @pytest.fixture(autouse=True)
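
The test edits show the new calling convention: metadata is installed via `set_forward_context`, so `attn.forward` now takes only the query, key, and value tensors. A hedged fragment of that pattern, reusing names from the test above (the fixtures that build `attn`, `attn_metadata`, `vllm_config`, and the packed tensors are assumed):

```python
from vllm.forward_context import set_forward_context

# attn, attn_metadata, vllm_config, reshaped_query and packed_qkv are assumed
# to be constructed by the surrounding test fixtures, as in the diff above.
with set_forward_context(attn_metadata, vllm_config):
    output = attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)
```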

vllm/attention/layer.py

Lines changed: 7 additions & 12 deletions

@@ -7,7 +7,7 @@
 import torch.nn.functional as F

 import vllm.envs as envs
-from vllm.attention import AttentionMetadata, AttentionType
+from vllm.attention import AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -153,15 +153,10 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
-        # directly, use `self.kv_cache` and
-        # `get_forward_context().attn_metadata` instead.
         if self.calculate_kv_scales:
-            ctx_attn_metadata = get_forward_context().attn_metadata
-            if ctx_attn_metadata.enable_kv_scales_calculation:
+            attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(key, value)
         if self.use_output:
             output = torch.empty_like(query)
@@ -177,14 +172,14 @@ def forward(
             value = value.view(-1, self.num_kv_heads, self.head_size)
             if self.use_direct_call:
                 forward_context: ForwardContext = get_forward_context()
-                ctx_attn_metadata = forward_context.attn_metadata
+                attn_metadata = forward_context.attn_metadata
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 self.impl.forward(self,
                                   query,
                                   key,
                                   value,
                                   self_kv_cache,
-                                  ctx_attn_metadata,
+                                  attn_metadata,
                                   output=output)
             else:
                 torch.ops.vllm.unified_attention_with_output(
@@ -193,10 +188,10 @@
         else:
             if self.use_direct_call:
                 forward_context = get_forward_context()
-                ctx_attn_metadata = forward_context.attn_metadata
+                attn_metadata = forward_context.attn_metadata
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 return self.impl.forward(self, query, key, value,
-                                         self_kv_cache, ctx_attn_metadata)
+                                         self_kv_cache, attn_metadata)
             else:
                 return torch.ops.vllm.unified_attention(
                     query, key, value, self.layer_name)
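
Put differently, the layer now resolves its KV cache and attention metadata internally rather than receiving them as arguments. A simplified sketch of that lookup, not the full `Attention.forward` implementation (the stand-in class name is illustrative):

```python
import torch

from vllm.forward_context import get_forward_context


class AttentionSketch:  # illustrative stand-in for vllm.attention.layer.Attention
    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor) -> torch.Tensor:
        forward_context = get_forward_context()
        # Metadata comes from the forward context; the KV cache is owned by
        # the layer and selected per virtual engine.
        attn_metadata = forward_context.attn_metadata
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        return self.impl.forward(self, query, key, value, self_kv_cache,
                                 attn_metadata)
```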

vllm/model_executor/layers/mamba/mamba_mixer.py

Lines changed: 3 additions & 2 deletions

@@ -7,6 +7,7 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -130,14 +131,14 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
         ) if use_rms_norm else None

     def forward_native(self, hidden_states: torch.Tensor,
-                       attn_metadata: AttentionMetadata,
                        conv_state: torch.Tensor, ssm_state: torch.Tensor):
         pass

     def forward_cuda(self, hidden_states: torch.Tensor,
-                     attn_metadata: AttentionMetadata,
                      mamba_cache_params: MambaCacheParams):

+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
+
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
         hidden_states, gate = projected_states.chunk(2, dim=-2)

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 2 additions & 2 deletions

@@ -14,6 +14,7 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -376,17 +377,16 @@ def __init__(self,
                                     eps=rms_norm_eps)

     def forward_native(self, hidden_states: torch.Tensor,
-                       attn_metadata: AttentionMetadata,
                        conv_state: torch.Tensor, ssm_state: torch.Tensor):
         pass

     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
     ):
+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size

vllm/model_executor/models/adapters.py

Lines changed: 1 addition & 5 deletions

@@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T:
         return cls

     # Lazy import
-    from vllm.attention import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.linear import RowParallelLinear
     from vllm.model_executor.layers.pooler import PoolingType
@@ -201,13 +200,10 @@ def forward(
             self,
             input_ids: torch.Tensor,
             positions: torch.Tensor,
-            kv_caches: list[torch.Tensor],
-            attn_metadata: AttentionMetadata,
             intermediate_tensors: Optional[IntermediateTensors] = None,
             inputs_embeds: Optional[torch.Tensor] = None,
         ) -> torch.Tensor:
-            hidden_states = super().forward(input_ids, positions, kv_caches,
-                                            attn_metadata,
+            hidden_states = super().forward(input_ids, positions,
                                             intermediate_tensors,
                                             inputs_embeds)
             logits, _ = self.score(hidden_states)

vllm/model_executor/models/arctic.py

Lines changed: 5 additions & 19 deletions

@@ -5,7 +5,7 @@
 import torch
 from torch import nn

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -283,13 +283,11 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -336,16 +334,12 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual_input = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual_input + hidden_states

@@ -400,8 +394,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -413,11 +405,8 @@ def forward(
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(positions, hidden_states,
-                                  kv_caches[i - self.start_layer],
-                                  attn_metadata)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.norm(hidden_states)
@@ -458,13 +447,10 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
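
The same mechanical edit repeats across the remaining model files: attention calls drop `kv_cache`/`attn_metadata`, and layer loops iterate over the layers directly instead of indexing into `kv_caches`. A hedged sketch of the resulting per-model pattern (class and attribute names are illustrative):

```python
import torch
from torch import nn


class DecoderModelSketch(nn.Module):  # illustrative, not an actual vLLM class
    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        # Only positions and hidden states are threaded through; each layer's
        # attention module pulls kv_cache/attn_metadata from the forward
        # context on its own.
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(positions, hidden_states)
        return self.norm(hidden_states)
```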

vllm/model_executor/models/aria.py

Lines changed: 0 additions & 5 deletions

@@ -9,7 +9,6 @@
 from transformers.models.aria.modeling_aria import AriaCrossAttention
 from transformers.models.aria.processing_aria import AriaProcessor

-from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
@@ -626,8 +625,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -643,8 +640,6 @@ def forward(
         hidden_states = self.language_model(
             input_ids,
             positions,
-            kv_caches,
-            attn_metadata,
             intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )

vllm/model_executor/models/baichuan.py

Lines changed: 5 additions & 19 deletions

@@ -20,13 +20,13 @@
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -182,14 +182,12 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -232,8 +230,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -246,8 +242,6 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         # Fully Connected
@@ -301,8 +295,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -316,13 +308,10 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
             )
         if not get_pp_group().is_last_rank:
@@ -379,13 +368,10 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
