```diff
@@ -21,20 +21,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only LLaMA model compatible with HuggingFace weights."""
+"""Inference-only SwissAI model compatible with HuggingFace weights."""
 from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
 
 import torch
 from torch import nn
-from transformers import LlamaConfig
+from transformers import SwissAIConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import XIELU
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (QKVParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -66,7 +67,7 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.up_proj = RowParallelLinear(
+        self.up_proj = ColumnParallelLinear(
             input_size=hidden_size,
             output_size=intermediate_size,
             bias=bias,
```
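The `up_proj` change above is the substantive fix in this hunk: in Megatron-style tensor parallelism, the projection that expands `hidden_size` to `intermediate_size` should shard its output features across ranks (column-parallel), so the row-parallel `down_proj` can consume each rank's local shard directly and perform a single all-reduce on the way back to `hidden_size`. A `RowParallelLinear` here would instead shard the input dimension and force an extra all-reduce before the activation. Below is a minimal sketch of the resulting MLP, assuming the non-gated xIELU design the imports suggest; details beyond what the diff shows (the `XIELU()` default construction, the `prefix` wiring, the `SwissAIMLP` class name) are assumptions, not the file's actual code.

```python
import torch
from torch import nn

from vllm.model_executor.layers.activation import XIELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)


class SwissAIMLP(nn.Module):
    """Sketch of a non-gated MLP: up-project, xIELU, down-project."""

    def __init__(self, hidden_size: int, intermediate_size: int,
                 bias: bool = False, prefix: str = "") -> None:
        super().__init__()
        # Column-parallel: the intermediate dimension is sharded across
        # tensor-parallel ranks, so no communication happens here.
        self.up_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=bias,
            prefix=f"{prefix}.up_proj",  # assumed naming, mirrors the diff
        )
        # Row-parallel: each rank consumes its local shard and the partial
        # outputs are all-reduced back to the full hidden size.
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = XIELU()  # assumed default-constructed, per the import

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # vLLM parallel linear layers return (output, optional bias).
        x, _ = self.up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x
```

With this layout the only cross-rank communication in the MLP is the all-reduce inside `down_proj`, matching the column-parallel-up, row-parallel-down pattern vLLM's existing Llama MLP uses.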
```diff
@@ -95,7 +96,7 @@ def forward(self, x):
 class SwissAIAttention(nn.Module):
 
     def __init__(self,
-                 config: LlamaConfig,
+                 config: SwissAIConfig,
                  hidden_size: int,
                  num_heads: int,
                  num_kv_heads: int,
@@ -216,7 +217,7 @@ class SwissAIDecoderLayer(nn.Module):
 
     def __init__(
         self,
-        config: LlamaConfig,
+        config: SwissAIConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
```