@@ -16,6 +16,7 @@
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -226,7 +227,7 @@ def apply(self,
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
 
 
-class LinearBase(torch.nn.Module):
+class LinearBase(CustomOp):
     """Base linear layer.
 
     Args:
@@ -269,12 +270,8 @@ def __init__(
                                                               prefix=prefix)
         self.return_bias = return_bias
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        raise NotImplementedError
-
 
+@CustomOp.register("replicated_linear")
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
 
@@ -443,6 +440,7 @@ def weight_loader(self,
             param[shard_offset:shard_offset + shard_size] = loaded_weight
 
 
+@CustomOp.register("column_parallel_linear")
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
 
@@ -1229,6 +1227,7 @@ def weight_loader(self,
         param_data.copy_(loaded_weight)
 
 
+@CustomOp.register("row_parallel_linear")
 class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
 
@@ -1405,6 +1404,7 @@ def extra_repr(self) -> str:
         return s
 
 
+@CustomOp.register("qkv_cross_parallel_linear")
 class QKVCrossParallelLinear(LinearBase):
     """Linear layers for efficient cross-attention's QKV transformation.
 
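For context: the diff rebases LinearBase from torch.nn.Module onto vLLM's CustomOp, which supplies a shared forward() that routes to a backend-specific implementation (forward_native, forward_cuda, ...), so the abstract forward() stub becomes redundant and is removed. Each @CustomOp.register("...") decorator files that linear variant in the op registry under a string key, letting it be looked up and enabled or disabled by name (e.g. via VLLM_CUSTOM_OPS). The sketch below is a minimal, simplified stand-in for that register-and-dispatch pattern, not vLLM's actual implementation; the toy ReplicatedLinear at the end exists only to show the decorator in use.

from typing import Any, Callable

import torch


class CustomOp(torch.nn.Module):
    """Simplified stand-in for vllm.model_executor.custom_op.CustomOp."""

    # Maps the string key passed to register() to the registered subclass.
    op_registry: dict[str, type["CustomOp"]] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        """Class decorator: file the decorated subclass under `name`."""
        def decorator(op_cls: type["CustomOp"]) -> type["CustomOp"]:
            cls.op_registry[name] = op_cls
            return op_cls
        return decorator

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # One shared forward() that routes to a backend-specific method,
        # so subclasses no longer override forward() directly. (Real vLLM
        # resolves the dispatch target once, per platform and config.)
        if torch.cuda.is_available():
            return self.forward_cuda(*args, **kwargs)
        return self.forward_native(*args, **kwargs)

    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def forward_cuda(self, *args: Any, **kwargs: Any) -> Any:
        # Default CUDA path: fall back to the portable implementation.
        return self.forward_native(*args, **kwargs)


@CustomOp.register("replicated_linear")
class ReplicatedLinear(CustomOp):
    """Toy linear op registered under a string key (illustration only)."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, self.weight)


# The registry lets callers look ops up (and gate them) by name.
assert CustomOp.op_registry["replicated_linear"] is ReplicatedLinear
print(ReplicatedLinear(4, 2)(torch.ones(3, 4)).shape)  # torch.Size([3, 2])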