|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +# Copyright 2025 The Zhipu AI team. |
| 4 | +# Copyright 2023 The vLLM team. |
| 5 | +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. |
| 6 | +# |
| 7 | +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX |
| 8 | +# and OPT implementations in this library. It has been modified from its |
| 9 | +# original forms to accommodate minor architectural differences compared |
| 10 | +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. |
| 11 | +# |
| 12 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 13 | +# you may not use this file except in compliance with the License. |
| 14 | +# You may obtain a copy of the License at |
| 15 | +# |
| 16 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 17 | +# |
| 18 | +# Unless required by applicable law or agreed to in writing, software |
| 19 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 20 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 21 | +# See the License for the specific language governing permissions and |
| 22 | +# limitations under the License. |
| 23 | +"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" |
| 24 | +from typing import Iterable, Optional, Set, Tuple, Union |
| 25 | + |
| 26 | +import torch |
| 27 | +from torch import nn |
| 28 | +from transformers import Glm4Config |
| 29 | + |
| 30 | +from vllm.attention import Attention, AttentionType |
| 31 | +from vllm.compilation.decorators import support_torch_compile |
| 32 | +from vllm.config import CacheConfig, VllmConfig |
| 33 | +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size |
| 34 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 35 | +from vllm.model_executor.layers.linear import (QKVParallelLinear, |
| 36 | + RowParallelLinear) |
| 37 | +from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| 38 | +from vllm.model_executor.layers.quantization import QuantizationConfig |
| 39 | +from vllm.model_executor.layers.rotary_embedding import get_rope |
| 40 | +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler |
| 41 | +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead |
| 42 | +from vllm.model_executor.sampling_metadata import SamplingMetadata |
| 43 | +from vllm.sequence import IntermediateTensors |
| 44 | + |
| 45 | +from .interfaces import SupportsLoRA, SupportsPP |
| 46 | +from .llama import LlamaMLP as Glm4MLP |
| 47 | +from .llama import LlamaModel |
| 48 | +from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix |
| 49 | + |
| 50 | + |
| 51 | +class Glm4Attention(nn.Module): |
| 52 | + |
| 53 | + def __init__(self, |
| 54 | + config: Glm4Config, |
| 55 | + hidden_size: int, |
| 56 | + num_heads: int, |
| 57 | + num_kv_heads: int, |
| 58 | + max_position: int = 4096 * 32, |
| 59 | + head_dim: Optional[int] = None, |
| 60 | + qkv_bias: bool = False, |
| 61 | + rope_theta: float = 10000, |
| 62 | + cache_config: Optional[CacheConfig] = None, |
| 63 | + quant_config: Optional[QuantizationConfig] = None, |
| 64 | + rope_scaling: Optional[Tuple] = None, |
| 65 | + prefix: str = "", |
| 66 | + attn_type: str = AttentionType.DECODER) -> None: |
| 67 | + super().__init__() |
| 68 | + self.hidden_size = hidden_size |
| 69 | + tp_size = get_tensor_model_parallel_world_size() |
| 70 | + self.total_num_heads = num_heads |
| 71 | + assert self.total_num_heads % tp_size == 0 |
| 72 | + self.num_heads = self.total_num_heads // tp_size |
| 73 | + self.total_num_kv_heads = num_kv_heads |
| 74 | + if self.total_num_kv_heads >= tp_size: |
| 75 | + # Number of KV heads is greater than TP size, so we partition |
| 76 | + # the KV heads across multiple tensor parallel GPUs. |
| 77 | + assert self.total_num_kv_heads % tp_size == 0 |
| 78 | + else: |
| 79 | + # Number of KV heads is less than TP size, so we replicate |
| 80 | + # the KV heads across multiple tensor parallel GPUs. |
| 81 | + assert tp_size % self.total_num_kv_heads == 0 |
| 82 | + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) |
| 83 | + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) |
| 84 | + self.head_dim = head_dim or hidden_size // self.total_num_heads |
| 85 | + self.rotary_dim = int(partial_rotary_factor * self.head_dim) |
| 86 | + self.q_size = self.num_heads * self.head_dim |
| 87 | + self.kv_size = self.num_kv_heads * self.head_dim |
| 88 | + self.scaling = self.head_dim**-0.5 |
| 89 | + self.rope_theta = rope_theta |
| 90 | + self.qkv_proj = QKVParallelLinear( |
| 91 | + hidden_size, |
| 92 | + self.head_dim, |
| 93 | + self.total_num_heads, |
| 94 | + self.total_num_kv_heads, |
| 95 | + bias=qkv_bias, |
| 96 | + quant_config=quant_config, |
| 97 | + prefix=f"{prefix}.qkv_proj", |
| 98 | + ) |
| 99 | + self.o_proj = RowParallelLinear( |
| 100 | + self.total_num_heads * self.head_dim, |
| 101 | + hidden_size, |
| 102 | + bias=False, |
| 103 | + quant_config=quant_config, |
| 104 | + prefix=f"{prefix}.o_proj", |
| 105 | + ) |
| 106 | + self.rotary_emb = get_rope( |
| 107 | + self.head_dim, |
| 108 | + rotary_dim=self.rotary_dim, |
| 109 | + max_position=max_position, |
| 110 | + base=self.rope_theta, |
| 111 | + rope_scaling=rope_scaling, |
| 112 | + partial_rotary_factor=partial_rotary_factor, |
| 113 | + ) |
| 114 | + self.attn = Attention(self.num_heads, |
| 115 | + self.head_dim, |
| 116 | + self.scaling, |
| 117 | + num_kv_heads=self.num_kv_heads, |
| 118 | + cache_config=cache_config, |
| 119 | + quant_config=quant_config, |
| 120 | + prefix=f"{prefix}.attn", |
| 121 | + attn_type=attn_type) |
| 122 | + |
| 123 | + def forward( |
| 124 | + self, |
| 125 | + positions: torch.Tensor, |
| 126 | + hidden_states: torch.Tensor, |
| 127 | + ) -> torch.Tensor: |
| 128 | + qkv, _ = self.qkv_proj(hidden_states) |
| 129 | + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) |
| 130 | + q, k = self.rotary_emb(positions, q, k) |
| 131 | + attn_output = self.attn(q, k, v) |
| 132 | + output, _ = self.o_proj(attn_output) |
| 133 | + return output |
| 134 | + |
| 135 | + |
| 136 | +class Glm4DecoderLayer(nn.Module): |
| 137 | + |
| 138 | + def __init__( |
| 139 | + self, |
| 140 | + config: Glm4Config, |
| 141 | + cache_config: Optional[CacheConfig] = None, |
| 142 | + quant_config: Optional[QuantizationConfig] = None, |
| 143 | + prefix: str = "", |
| 144 | + ) -> None: |
| 145 | + super().__init__() |
| 146 | + self.hidden_size = config.hidden_size |
| 147 | + rope_theta = getattr(config, "rope_theta", 1000000) |
| 148 | + rope_scaling = getattr(config, "rope_scaling", None) |
| 149 | + |
| 150 | + self.self_attn = Glm4Attention( |
| 151 | + config=config, |
| 152 | + hidden_size=self.hidden_size, |
| 153 | + num_heads=config.num_attention_heads, |
| 154 | + max_position=config.max_position_embeddings, |
| 155 | + num_kv_heads=config.num_key_value_heads, |
| 156 | + rope_theta=rope_theta, |
| 157 | + qkv_bias=getattr(config, 'attention_bias', False), |
| 158 | + head_dim=getattr(config, 'head_dim', None), |
| 159 | + cache_config=cache_config, |
| 160 | + quant_config=quant_config, |
| 161 | + rope_scaling=rope_scaling, |
| 162 | + prefix=f"{prefix}.self_attn", |
| 163 | + attn_type=AttentionType.DECODER, |
| 164 | + ) |
| 165 | + self.mlp = Glm4MLP( |
| 166 | + hidden_size=self.hidden_size, |
| 167 | + intermediate_size=config.intermediate_size, |
| 168 | + hidden_act=config.hidden_act, |
| 169 | + quant_config=quant_config, |
| 170 | + prefix=f"{prefix}.mlp", |
| 171 | + ) |
| 172 | + self.input_layernorm = RMSNorm(config.hidden_size, |
| 173 | + eps=config.rms_norm_eps) |
| 174 | + self.post_attention_layernorm = RMSNorm(config.hidden_size, |
| 175 | + eps=config.rms_norm_eps) |
| 176 | + self.post_self_attn_layernorm = RMSNorm(config.hidden_size, |
| 177 | + eps=config.rms_norm_eps) |
| 178 | + self.post_mlp_layernorm = RMSNorm(config.hidden_size, |
| 179 | + eps=config.rms_norm_eps) |
| 180 | + |
| 181 | + def forward( |
| 182 | + self, |
| 183 | + positions: torch.Tensor, |
| 184 | + hidden_states: torch.Tensor, |
| 185 | + residual: Optional[torch.Tensor], |
| 186 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 187 | + # Self Attention |
| 188 | + if residual is None: |
| 189 | + residual = hidden_states |
| 190 | + hidden_states = self.input_layernorm(hidden_states) |
| 191 | + else: |
| 192 | + hidden_states, residual = self.input_layernorm( |
| 193 | + hidden_states, residual) |
| 194 | + hidden_states = self.self_attn( |
| 195 | + positions=positions, |
| 196 | + hidden_states=hidden_states, |
| 197 | + ) |
| 198 | + |
| 199 | + hidden_states = self.post_self_attn_layernorm(hidden_states) |
| 200 | + hidden_states = residual + hidden_states |
| 201 | + |
| 202 | + # Fully Connected |
| 203 | + hidden_states = self.post_attention_layernorm(hidden_states, residual) |
| 204 | + hidden_states = self.mlp(hidden_states) |
| 205 | + hidden_states = self.post_mlp_layernorm(hidden_states) |
| 206 | + hidden_states = residual + hidden_states |
| 207 | + |
| 208 | + return hidden_states, residual |
| 209 | + |
| 210 | + |
| 211 | +ALL_DECODER_LAYER_TYPES = { |
| 212 | + "attention": Glm4DecoderLayer, |
| 213 | +} |
| 214 | + |
| 215 | + |
| 216 | +@support_torch_compile( |
| 217 | + dynamic_arg_dims={ |
| 218 | + "input_ids": 0, |
| 219 | + "positions": -1, |
| 220 | + "intermediate_tensors": 0, |
| 221 | + "inputs_embeds": 0, |
| 222 | + }) |
| 223 | +class Glm4Model(LlamaModel): |
| 224 | + |
| 225 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 226 | + super().__init__(vllm_config=vllm_config, |
| 227 | + prefix=prefix, |
| 228 | + layer_type=Glm4DecoderLayer) |
| 229 | + |
| 230 | + |
| 231 | +class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): |
| 232 | + packed_modules_mapping = { |
| 233 | + "qkv_proj": [ |
| 234 | + "q_proj", |
| 235 | + "k_proj", |
| 236 | + "v_proj", |
| 237 | + ], |
| 238 | + "gate_up_proj": [ |
| 239 | + "gate_proj", |
| 240 | + "up_proj", |
| 241 | + ], |
| 242 | + } |
| 243 | + |
| 244 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 245 | + super().__init__() |
| 246 | + config = vllm_config.model_config.hf_config |
| 247 | + quant_config = vllm_config.quant_config |
| 248 | + lora_config = vllm_config.lora_config |
| 249 | + |
| 250 | + self.config = config |
| 251 | + self.lora_config = lora_config |
| 252 | + |
| 253 | + self.quant_config = quant_config |
| 254 | + self.model = Glm4Model(vllm_config=vllm_config, |
| 255 | + prefix=maybe_prefix(prefix, "model")) |
| 256 | + |
| 257 | + if get_pp_group().is_last_rank: |
| 258 | + if config.tie_word_embeddings: |
| 259 | + self.lm_head = self.model.embed_tokens |
| 260 | + else: |
| 261 | + self.lm_head = ParallelLMHead(config.vocab_size, |
| 262 | + config.hidden_size, |
| 263 | + quant_config=quant_config, |
| 264 | + prefix=maybe_prefix( |
| 265 | + prefix, "lm_head")) |
| 266 | + else: |
| 267 | + self.lm_head = PPMissingLayer() |
| 268 | + |
| 269 | + self.logits_processor = LogitsProcessor(config.vocab_size) |
| 270 | + self.sampler = get_sampler() |
| 271 | + |
| 272 | + self.make_empty_intermediate_tensors = ( |
| 273 | + self.model.make_empty_intermediate_tensors) |
| 274 | + |
| 275 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
| 276 | + return self.model.get_input_embeddings(input_ids) |
| 277 | + |
| 278 | + def forward( |
| 279 | + self, |
| 280 | + input_ids: torch.Tensor, |
| 281 | + positions: torch.Tensor, |
| 282 | + intermediate_tensors: Optional[IntermediateTensors] = None, |
| 283 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 284 | + ) -> Union[torch.Tensor, IntermediateTensors]: |
| 285 | + hidden_states = self.model(input_ids, positions, intermediate_tensors, |
| 286 | + inputs_embeds) |
| 287 | + return hidden_states |
| 288 | + |
| 289 | + def compute_logits( |
| 290 | + self, |
| 291 | + hidden_states: torch.Tensor, |
| 292 | + sampling_metadata: SamplingMetadata, |
| 293 | + ) -> Optional[torch.Tensor]: |
| 294 | + logits = self.logits_processor(self.lm_head, hidden_states, |
| 295 | + sampling_metadata) |
| 296 | + return logits |
| 297 | + |
| 298 | + def sample( |
| 299 | + self, |
| 300 | + logits: torch.Tensor, |
| 301 | + sampling_metadata: SamplingMetadata, |
| 302 | + ) -> Optional[SamplerOutput]: |
| 303 | + next_tokens = self.sampler(logits, sampling_metadata) |
| 304 | + return next_tokens |
| 305 | + |
| 306 | + def load_weights(self, weights: Iterable[Tuple[str, |
| 307 | + torch.Tensor]]) -> Set[str]: |
| 308 | + loader = AutoWeightsLoader( |
| 309 | + self, |
| 310 | + skip_prefixes=(["lm_head."] |
| 311 | + if self.config.tie_word_embeddings else None), |
| 312 | + ) |
| 313 | + return loader.load_weights(weights) |
0 commit comments