 # limitations under the License.
 """Inference-only deci model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)
 
 
+class DeciLMAttention(LlamaAttention):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
+        super().__init__(config, hidden_size, num_heads, num_kv_heads,
+                         rope_theta, rope_scaling, max_position_embeddings,
+                         quant_config, bias, bias_o_proj, cache_config, prefix,
+                         attn_type)
+
+    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        # Enables YARN for Mistral and LLaMA4 derivatives.
+        is_neox_style = True
+        if hasattr(config, "position_embedding_type"):
+            is_neox_style = config.position_embedding_type not in [
+                "mistral_yarn", "rope_llama4"
+            ]
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor)
+
+
 class DeciLMDecoderLayer(nn.Module):
 
     def __init__(
@@ -98,7 +142,7 @@ def __init__(
         if not self._is_no_op_attention:
            num_kv_heads = (config.num_attention_heads //
                            block_config.attention.n_heads_in_group)
-           self.self_attn = LlamaAttention(
+           self.self_attn = DeciLMAttention(
                config=config,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,