 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional, Union, cast
+from typing import Optional, Union

 import torch
 import torch.nn as nn
@@ -32,8 +32,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
         == len(layer.lora_b_stacked)
         == len(layer.output_slices)
     )
-    if layer.lora_bias_stacked is not None:
-        assert layer.n_slices == len(layer.lora_bias_stacked)

     output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

@@ -61,7 +59,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
         output,
         buffers,
         layer.lora_b_stacked,
-        layer.lora_bias_stacked,
         layer.output_slices,
         offset_start=0,
         add_input=True,
@@ -122,16 +119,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         lora_b = lora_b[start_idx:end_idx, :]
         return lora_b

-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        # TODO: Fix the slicing logic of bias.
-        if bias is None:
-            return bias
-        shard_size = self.output_size
-        start_idx = self.tp_rank * shard_size
-        end_idx = (self.tp_rank + 1) * shard_size
-        bias = bias[start_idx:end_idx]
-        return bias
-
     def forward(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
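Note on the hunk above: the removed slice_bias mirrored the layer's column-parallel sharding, with each tensor-parallel rank keeping the contiguous slice of the bias that matches its output shard. A minimal standalone sketch of that slicing, using shard_bias, tp_rank, and shard_size as illustrative stand-ins for the layer attributes:

import torch

def shard_bias(bias: torch.Tensor, tp_rank: int, shard_size: int) -> torch.Tensor:
    # Rank tp_rank owns elements [tp_rank * shard_size, (tp_rank + 1) * shard_size).
    start_idx = tp_rank * shard_size
    end_idx = (tp_rank + 1) * shard_size
    return bias[start_idx:end_idx]

# Example: a bias of length 8 split across two tensor-parallel ranks.
full_bias = torch.arange(8.0)
assert shard_bias(full_bias, tp_rank=0, shard_size=4).tolist() == [0.0, 1.0, 2.0, 3.0]
assert shard_bias(full_bias, tp_rank=1, shard_size=4).tolist() == [4.0, 5.0, 6.0, 7.0]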
@@ -238,17 +225,6 @@ def create_lora_weights(
             )
             for output_size in self.output_slices
         )
-        if lora_config.bias_enabled:
-            self.lora_bias_stacked = tuple(
-                torch.zeros(
-                    max_loras,
-                    1,
-                    output_size,
-                    dtype=lora_config.lora_dtype,
-                    device=self.device,
-                )
-                for output_size in self.output_slices
-            )

     def slice_lora_a(
         self, lora_a: list[Union[torch.Tensor, None]]
@@ -268,31 +244,18 @@ def slice_lora_b(
         ]
         return sliced_lora_b

-    def slice_bias(
-        self, bias: list[Union[torch.Tensor, None]]
-    ) -> list[Union[torch.Tensor, None]]:
-        for i, (shard_id, shard_size) in enumerate(
-            zip(self.output_ids, self.output_slices)
-        ):
-            if (bias_i := bias[i]) is not None:
-                bias[i] = bias_i[shard_size * shard_id : shard_size * (shard_id + 1)]
-        return bias
-
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        lora_bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)

         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
-            if lora_bias is not None:
-                lora_bias = self.slice_bias(lora_bias)

         for i in range(self.n_slices):
             if (lora_a_i := lora_a[i]) is not None:
@@ -304,16 +267,6 @@ def set_lora(
                 index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
             ].copy_(lora_b_i, non_blocking=True)

-        if lora_bias is not None:
-            self.lora_bias_stacked = cast(
-                tuple[torch.Tensor, ...], self.lora_bias_stacked
-            )
-            for i in range(self.n_slices):
-                if (lora_bias_i := lora_bias[i]) is not None:
-                    self.lora_bias_stacked[i][index, 0, : lora_bias_i.shape[0]].copy_(
-                        lora_bias_i, non_blocking=True
-                    )
-
     @classmethod
     @_not_fully_sharded_can_replace
     def can_replace_layer(
@@ -380,24 +333,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
         return lora_b

-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        bias_q = bias[
-            self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
-            * (self.q_shard_id + 1)
-        ]
-        k_offset = self.q_proj_total_size
-        bias_k = bias[
-            k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
-            + self.kv_proj_shard_size * (self.kv_shard_id + 1)
-        ]
-        v_offset = k_offset + self.kv_proj_total_size
-        bias_v = bias[
-            v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
-            + self.kv_proj_shard_size * (self.kv_shard_id + 1)
-        ]
-        bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
-        return bias
-
     @classmethod
     @_not_fully_sharded_can_replace
     def can_replace_layer(
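Note on the last hunk: the removed QKV slice_bias treated the unsharded bias as three contiguous blocks laid out as [q | k | v] and took the current rank's shard from each block before concatenating. A rough, self-contained sketch of that layout, with the sizes and shard ids passed in explicitly instead of read from the layer attributes (and concatenating along dim=0, since the bias is one-dimensional; the removed code used dim=1, which torch.cat rejects for 1-D inputs):

import torch

def shard_qkv_bias(
    bias: torch.Tensor,
    q_total: int,      # full size of the q block
    kv_total: int,     # full size of each of the k and v blocks
    q_shard: int,      # per-rank q shard size
    kv_shard: int,     # per-rank k/v shard size
    q_shard_id: int,
    kv_shard_id: int,
) -> torch.Tensor:
    # Unsharded layout: [ q (q_total) | k (kv_total) | v (kv_total) ]
    bias_q = bias[q_shard * q_shard_id : q_shard * (q_shard_id + 1)]
    k_offset = q_total
    bias_k = bias[
        k_offset + kv_shard * kv_shard_id : k_offset + kv_shard * (kv_shard_id + 1)
    ]
    v_offset = q_total + kv_total
    bias_v = bias[
        v_offset + kv_shard * kv_shard_id : v_offset + kv_shard * (kv_shard_id + 1)
    ]
    return torch.cat([bias_q, bias_k, bias_v], dim=0)

# Example: q block of 8, k/v blocks of 4 each, split across two ranks; rank 1 shown.
bias = torch.arange(16.0)
print(shard_qkv_bias(bias, q_total=8, kv_total=4, q_shard=4, kv_shard=2,
                     q_shard_id=1, kv_shard_id=1))
# tensor([ 4.,  5.,  6.,  7., 10., 11., 14., 15.])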