 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -182,14 +182,12 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
 
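Note on what this hunk relies on: after the change, `self.attn(q, k, v)` must find its KV cache and attention metadata without receiving them as arguments. In vLLM this works because the model runner publishes both in a per-step forward context that each `Attention` layer reads internally. Below is a minimal toy sketch of that pattern; it is for illustration only and does not reproduce vLLM's actual `vllm.forward_context` implementation.

```python
# Toy sketch of the "forward context" pattern the simplified attention call
# relies on: the runner publishes per-step state once, and attention layers
# read it instead of receiving kv_cache/attn_metadata as arguments.
# Illustrative only -- not vLLM's actual implementation.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterator, List, Optional

@dataclass
class ForwardContext:
    attn_metadata: Any       # batch layout (prefill/decode) for this step
    kv_caches: List[Any]     # one KV-cache tensor per attention layer

_current_ctx: Optional[ForwardContext] = None

@contextmanager
def set_forward_context(ctx: ForwardContext) -> Iterator[None]:
    """Install the context for the duration of one model step."""
    global _current_ctx
    _current_ctx = ctx
    try:
        yield
    finally:
        _current_ctx = None

def get_forward_context() -> ForwardContext:
    """Called from inside an attention layer instead of taking arguments."""
    assert _current_ctx is not None, "forward context not set for this step"
    return _current_ctx
```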
@@ -232,8 +230,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -246,8 +242,6 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
 
         # Fully Connected
@@ -301,8 +295,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -316,13 +308,10 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
             )
         if not get_pp_group().is_last_rank:
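The loop rewrite above works because slicing an `nn.ModuleList` returns another `ModuleList`, so iterating the slice visits exactly the layers owned by this pipeline-parallel rank while dropping the `kv_caches[i - self.start_layer]` index bookkeeping. A quick self-contained check (layer sizes and partition bounds are made up for illustration):

```python
import torch.nn as nn

# Hypothetical 8-layer stack; pretend this PP rank owns layers [2, 6).
layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(8))
start_layer, end_layer = 2, 6

subset = layers[start_layer:end_layer]
assert isinstance(subset, nn.ModuleList)   # slicing preserves the container
assert len(subset) == end_layer - start_layer
```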
@@ -379,13 +368,10 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
 
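For reference, a hedged sketch of a call site against the simplified top-level signature. The parameter names come straight from this diff; the model construction itself (vLLM engine/config plumbing) is elided and assumed.

```python
import torch

# Stand-in inputs with plausible shapes: vLLM flattens the batch into one
# token dimension, with a matching position for every token.
input_ids = torch.tensor([1, 2, 3, 4], dtype=torch.long)
positions = torch.tensor([0, 1, 2, 3], dtype=torch.long)

# Assuming `model` is an already-constructed BaiChuanForCausalLM:
# hidden_states = model(
#     input_ids=input_ids,
#     positions=positions,
#     intermediate_tensors=None,   # only populated between pipeline stages
#     inputs_embeds=None,          # optional precomputed embeddings
# )
```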