12 | 12 | from vllm.lora.layers import (ColumnParallelLinearWithLoRA, |
13 | 13 | MergedColumnParallelLinearWithLoRA, |
14 | 14 | MergedQKVParallelLinearWithLora, |
| 15 | + QKVParallelLinearWithLora, |
15 | 16 | RowParallelLinearWithLoRA) |
16 | 17 | from vllm.lora.punica import bgmv, dispatch_bgmv_low_level |
17 | 18 |
@@ -90,11 +91,11 @@ def can_replace_layer(cls, source_layer: nn.Module, |
90 | 91 | def _mcp_apply(x, bias, layer): |
91 | 92 | """ |
92 | 93 | MergedColumnParallelLinearWithShardedLoRA and |
93 | | - QKVParallelLinearWithShardedLora share the same |
| 94 | + MergedQKVParallelLinearWithShardedLora share the same |
94 | 95 | LoRA weight application method.
95 | 96 | |
96 | 97 | The main difference is the step by shard_size for lora_b which can |
97 | | - vary for QKVParallelLinearWithShardedLora but is constant for |
| 98 | + vary for MergedQKVParallelLinearWithShardedLora but is constant for |
98 | 99 | MergedColumnParallelLinearWithShardedLoRA. |
99 | 100 | """ |
100 | 101 | # expecting 2 for column parallel and 3 for qkv |
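
As a rough illustration of what `_mcp_apply` does (the hunk above shows only its docstring and the start of its body): each tensor-parallel rank projects with its rank-dim slice of every `lora_a`, the partial low-rank activations are all-gathered across the group, and each rank then applies its output-dim shard of the corresponding `lora_b`, stepping the write offset by `shard_size` per slice. The sketch below emulates the tensor-parallel group in a single process with plain matmuls instead of the punica `bgmv`/`dispatch_bgmv_low_level` kernels, and assumes a single LoRA adapter for all tokens (the real kernels dispatch per-token adapter indices); all shapes and names are illustrative assumptions, not vLLM APIs.

```python
import torch

torch.manual_seed(0)
tokens, hidden, lora_rank, out_per_slice, tp_size = 5, 32, 8, 16, 4
n_slices = 3                        # 2 for merged column-parallel, 3 for QKV
r_shard = lora_rank // tp_size      # lora_a is sliced along the rank dim (S-LoRA)
o_shard = out_per_slice // tp_size  # lora_b is sliced along the output dim

x = torch.randn(tokens, hidden)
lora_a = [torch.randn(hidden, lora_rank) for _ in range(n_slices)]
lora_b = [torch.randn(lora_rank, out_per_slice) for _ in range(n_slices)]

# Unsharded reference: the LoRA delta for each packed slice (e.g. q, k, v).
ref = [x @ a @ b for a, b in zip(lora_a, lora_b)]

# 1) Local work on every rank: project with that rank's slice of each lora_a.
partials = [[x @ a[:, r * r_shard:(r + 1) * r_shard] for a in lora_a]
            for r in range(tp_size)]            # partials[rank][slice]

for tp in range(tp_size):  # emulate what each tensor-parallel rank then does
    # 2) All-gather along the rank dim: every rank recovers the full
    #    (tokens, lora_rank) low-rank activations for each slice.
    gathered = [torch.cat([partials[r][i] for r in range(tp_size)], dim=-1)
                for i in range(n_slices)]
    # 3) Apply this rank's output-dim shard of each lora_b; successive slices
    #    are written at offsets that advance by shard_size, yielding a
    #    column-partitioned output.
    for i, b in enumerate(lora_b):
        local = gathered[i] @ b[:, tp * o_shard:(tp + 1) * o_shard]
        assert torch.allclose(local, ref[i][:, tp * o_shard:(tp + 1) * o_shard],
                              atol=1e-4)
```

Because concatenating the per-rank `x @ lora_a` slices along the rank dim reproduces the full low-rank projection, sharding `lora_a` as well as `lora_b` costs only an all-gather over a small `(tokens, rank)` buffer rather than any communication proportional to the hidden size.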
@@ -167,14 +168,65 @@ def can_replace_layer(cls, source_layer: nn.Module, |
167 | 168 | ) |
168 | 169 |
169 | 170 |
170 | | -class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): |
| 171 | +class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): |
171 | 172 | """ |
172 | 173 | Differs from QKVParallelLinearWithLora by slicing the |
173 | 174 | LoRA A's also. |
174 | 175 |
175 | 176 | Based on S-LoRA, slicing happens along the rank dim. |
176 | 177 | """ |
177 | 178 |
| 179 | + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: |
| 180 | + tp_rank = get_tensor_model_parallel_rank() |
| 181 | + shard_size = self.lora_a_stacked.shape[2] |
| 182 | + start_idx = tp_rank * shard_size |
| 183 | + lora_a = lora_a[:, start_idx:start_idx + shard_size] |
| 184 | + return lora_a |
| 185 | + |
| 186 | + def apply(self, x: torch.Tensor, |
| 187 | + bias: Optional[torch.Tensor]) -> torch.Tensor: |
| 188 | + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) |
| 189 | + |
| 190 | + x = x.view(-1, x.shape[-1]) |
| 191 | + output, out_orig_shape = output.view(-1, |
| 192 | + output.shape[-1]), output.shape |
| 193 | + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), |
| 194 | + dtype=torch.float32, |
| 195 | + device=x.device) |
| 196 | + |
| 197 | + bgmv(buffer, x, self.lora_a_stacked, |
| 198 | + self.indices[:self.indices_len[0]], 0, 1.0) |
| 199 | + buffer = tensor_model_parallel_all_gather(buffer) |
| 200 | + bgmv(output, buffer, self.lora_b_stacked, |
| 201 | + self.indices[:self.indices_len[0]], 0, 1.0) |
| 202 | + # now have column partitioned output |
| 203 | + |
| 204 | + output = output.view(*out_orig_shape) |
| 205 | + return output |
| 206 | + |
| 207 | + @classmethod |
| 208 | + @_fully_sharded_can_replace |
| 209 | + def can_replace_layer(cls, source_layer: nn.Module, |
| 210 | + lora_config: LoRAConfig, packed_modules_list: List, |
| 211 | + model_config: Optional[PretrainedConfig]) -> bool: |
| 212 | + # specifying kwargs so they can be easily accessed in decorator |
| 213 | + return super().can_replace_layer( |
| 214 | + source_layer=source_layer, |
| 215 | + lora_config=lora_config, |
| 216 | + packed_modules_list=packed_modules_list, |
| 217 | + model_config=model_config, |
| 218 | + decorate=False, |
| 219 | + ) |
| 220 | + |
| 221 | + |
| 222 | +class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): |
| 223 | + """ |
| 224 | + Differs from MergedQKVParallelLinearWithLora by slicing the |
| 225 | + LoRA A's also. |
| 226 | +
| 227 | + Based on S-LoRA, slicing happens along the rank dim. |
| 228 | + """ |
| 229 | + |
178 | 230 | def slice_lora_a( |
179 | 231 | self, lora_a: List[Union[torch.Tensor, None]] |
180 | 232 | ) -> List[Union[torch.Tensor, None]]: |
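
The diff is truncated at the merged `slice_lora_a` signature, but the contrast with the single-matrix version added above is straightforward: `QKVParallelLinearWithShardedLora` slices one LoRA A along the rank dim, while `MergedQKVParallelLinearWithShardedLora` applies the same rank-dim slicing to each packed projection's LoRA A. A minimal standalone sketch of the two behaviours, with `tp_rank` and the shard sizes passed in explicitly rather than read from the TP group and the stacked buffers, and with simplified None handling (both simplifications are assumptions, not the vLLM signatures):

```python
from typing import List, Optional

import torch


def slice_single_lora_a(lora_a: torch.Tensor, tp_rank: int,
                        shard_size: int) -> torch.Tensor:
    # Single-matrix case: one LoRA A, sliced along the rank dim so this
    # tensor-parallel rank keeps only its shard_size columns.
    start = tp_rank * shard_size
    return lora_a[:, start:start + shard_size]


def slice_merged_lora_a(
        lora_a: List[Optional[torch.Tensor]], tp_rank: int,
        shard_sizes: List[int]) -> List[Optional[torch.Tensor]]:
    # Merged case: one LoRA A per packed projection (q, k, v); each present
    # matrix is sliced along its own rank dim, and missing entries are
    # passed through untouched (a simplifying assumption here).
    sliced: List[Optional[torch.Tensor]] = []
    for a, size in zip(lora_a, shard_sizes):
        if a is None:
            sliced.append(None)
        else:
            start = tp_rank * size
            sliced.append(a[:, start:start + size])
    return sliced
```

Either way, the all-gather in the shared apply path reassembles the full rank dimension before LoRA B is applied, so the sharded layers produce the same result as their unsharded counterparts.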