
Commit f03bedf

SwissAI parallelization bugfix (vllm-project#3)
* Update swissai.py: replaced LlamaConfig with SwissAIConfig; changed up_proj from RowParallelLinear to ColumnParallelLinear
* `ColumnParallelLinear` import
* `LLaMa` -> `SwissAI`

Co-authored-by: EduardDurech <39579228+EduardDurech@users.noreply.github.com>
1 parent 88fc1a5 commit f03bedf

File tree

1 file changed (+7, -6 lines)

vllm/model_executor/models/swissai.py

Lines changed: 7 additions & 6 deletions
@@ -21,20 +21,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only LLaMA model compatible with HuggingFace weights."""
+"""Inference-only SwissAI model compatible with HuggingFace weights."""
 from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
 
 import torch
 from torch import nn
-from transformers import LlamaConfig
+from transformers import SwissAIConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import XIELU
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (QKVParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -66,7 +67,7 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.up_proj = RowParallelLinear(
+        self.up_proj = ColumnParallelLinear(
             input_size=hidden_size,
             output_size=intermediate_size,
             bias=bias,
@@ -95,7 +96,7 @@ def forward(self, x):
 class SwissAIAttention(nn.Module):
 
     def __init__(self,
-                 config: LlamaConfig,
+                 config: SwissAIConfig,
                  hidden_size: int,
                  num_heads: int,
                  num_kv_heads: int,
@@ -216,7 +217,7 @@ class SwissAIDecoderLayer(nn.Module):
 
     def __init__(
        self,
-        config: LlamaConfig,
+        config: SwissAIConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
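
For context on the parallelization fix: the substantive change is that up_proj now uses ColumnParallelLinear. Below is a minimal illustrative sketch (plain PyTorch with toy sizes, ReLU standing in for the model's xIELU activation, and a Python sum standing in for the all-reduce; not the vLLM implementation) of the Megatron-style pairing this restores. The up projection is sharded along its output (intermediate) dimension, the down projection along its input dimension, so each rank keeps only a slice of the intermediate activation and a single reduction per MLP recovers the full result.

# Minimal sketch (toy sizes; not vLLM code) of column-parallel up_proj
# paired with row-parallel down_proj.
import torch

torch.manual_seed(0)
hidden, inter, tp = 8, 16, 2           # tp = number of tensor-parallel ranks
x = torch.randn(4, hidden)

w_up = torch.randn(inter, hidden)      # up_proj weight, shape (out, in)
w_down = torch.randn(hidden, inter)    # down_proj weight, shape (out, in)

# Unsharded reference: x -> up_proj -> activation -> down_proj.
ref = torch.relu(x @ w_up.t()) @ w_down.t()

# Column-parallel up_proj: each rank holds a slice of the *output*
# (intermediate) dimension, so its activation slice needs no communication.
# Row-parallel down_proj: each rank holds the matching slice of the *input*
# dimension and produces a partial output; summing the partials plays the
# role of the all-reduce that RowParallelLinear performs.
partials = []
for rank in range(tp):
    up_shard = w_up.chunk(tp, dim=0)[rank]      # (inter // tp, hidden)
    down_shard = w_down.chunk(tp, dim=1)[rank]  # (hidden, inter // tp)
    act = torch.relu(x @ up_shard.t())          # local activation slice
    partials.append(act @ down_shard.t())       # local partial output
out = sum(partials)                             # stands in for the all-reduce

assert torch.allclose(out, ref, atol=1e-5)

With up_proj declared as RowParallelLinear, the layer would instead shard the hidden (input) dimension and reduce immediately after the projection, which does not compose with the intended sharding of the rest of the MLP; that is the mismatch this commit corrects.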

0 commit comments
