
Commit 6d68030

[Model] Add support for YARN in NemotronNAS models (#18427)
Signed-off-by: Nave Assaf <nassaf@nvidia.com>
1 parent 5a2c76c commit 6d68030

File tree

2 files changed: +67, -16 lines

vllm/model_executor/models/llama.py

Lines changed: 21 additions & 14 deletions

@@ -162,20 +162,9 @@ def __init__(
             prefix=f"{prefix}.o_proj",
         )

-        is_neox_style = True
-        is_gguf = quant_config and quant_config.get_name() == "gguf"
-        if is_gguf and config.model_type == "llama":
-            is_neox_style = False
-
-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
-            partial_rotary_factor=self.partial_rotary_factor,
-        )
+        self._init_rotary_emb(config,
+                              rope_scaling=rope_scaling,
+                              quant_config=quant_config)

         if hasattr(config, "interleaved_sliding_window"):
             interleaved_sliding_window = config.interleaved_sliding_window
@@ -214,6 +203,24 @@ def forward(
         output, _ = self.o_proj(attn_output)
         return output

+    def _init_rotary_emb(self, config: LlamaConfig,
+                         rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        is_neox_style = True
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and self.config.model_type == "llama":
+            is_neox_style = False
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+

 class LlamaDecoderLayer(nn.Module):

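For context, the change above turns rotary-embedding construction into an overridable _init_rotary_emb hook, so attention subclasses can swap the RoPE setup without copying the rest of __init__. A minimal sketch of the pattern follows; MyAttention is a hypothetical subclass used for illustration only and is not part of this commit.

# Hypothetical illustration only; MyAttention is not part of this commit.
from typing import Any, Optional

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.llama import LlamaAttention


class MyAttention(LlamaAttention):
    """Variant that forces a non-NeoX (GPT-J style) rotary layout."""

    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        # Reuse the attributes LlamaAttention.__init__ already set up
        # (head_dim, rope_theta, max_position_embeddings, ...).
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,  # the only line that differs from the default
            partial_rotary_factor=self.partial_rotary_factor)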
vllm/model_executor/models/nemotron_nas.py

Lines changed: 46 additions & 2 deletions

@@ -23,18 +23,20 @@
 # limitations under the License.
 """Inference-only deci model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
 from torch import nn
 from transformers import LlamaConfig

+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)


+class DeciLMAttention(LlamaAttention):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
+        super().__init__(config, hidden_size, num_heads, num_kv_heads,
+                         rope_theta, rope_scaling, max_position_embeddings,
+                         quant_config, bias, bias_o_proj, cache_config, prefix,
+                         attn_type)
+
+    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        # Enables YARN for Mistral and LLaMA4 derivatives.
+        is_neox_style = True
+        if hasattr(config, "position_embedding_type"):
+            is_neox_style = config.position_embedding_type not in [
+                "mistral_yarn", "rope_llama4"
+            ]
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor)
+
+
 class DeciLMDecoderLayer(nn.Module):

     def __init__(
@@ -98,7 +142,7 @@ def __init__(
         if not self._is_no_op_attention:
             num_kv_heads = (config.num_attention_heads //
                             block_config.attention.n_heads_in_group)
-            self.self_attn = LlamaAttention(
+            self.self_attn = DeciLMAttention(
                 config=config,
                 hidden_size=self.hidden_size,
                 num_heads=config.num_attention_heads,

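As a usage note, the new override keys off position_embedding_type in the model's Hugging Face config. Below is a toy, self-contained sketch of that selection logic; the config object and the YARN-style rope_scaling fields are invented for illustration (they follow the common HF layout, not anything taken from this commit or a real NemotronNAS checkpoint).

# Toy sketch; config values are invented for illustration only.
from types import SimpleNamespace

cfg = SimpleNamespace(
    position_embedding_type="mistral_yarn",  # opts into the new branch
    rope_scaling={                            # assumed YARN-style scaling dict
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)

# Mirrors the check in DeciLMAttention._init_rotary_emb above.
is_neox_style = True
if hasattr(cfg, "position_embedding_type"):
    is_neox_style = cfg.position_embedding_type not in [
        "mistral_yarn", "rope_llama4"
    ]

print(is_neox_style)  # False

When the flag resolves to False, get_rope is asked for a non-NeoX (GPT-J style) rotary layout, and any YARN parameters still arrive through rope_scaling as before.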