Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion umbrella/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph
from .gemma import Gemma2
from .mistral import Mistral, MistralAwqOffload, MistralOffload, MistralCudagraph, MistralAwq
from .glm4 import Glm4
class AutoModelLM:
"""
自动模型加载器,根据模型类型动态加载对应的类。
Expand Down Expand Up @@ -117,7 +118,8 @@ class AutoModelLM:
"mistralai/Mistral-Small-24B-Instruct-2501": Mistral,
"stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwq,
"PyrTools/Ministral-8B-Instruct-2410-AWQ": MistralAwq,
"mistralai/Ministral-8B-Instruct-2410": Mistral
"mistralai/Ministral-8B-Instruct-2410": Mistral,
"THUDM/glm-4-9b-chat-hf": Glm4,
}

_CUDAGRAPH_MODEL_MAPPING = {
Expand Down
186 changes: 186 additions & 0 deletions umbrella/models/glm4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from transformers import GlmForCausalLM, GlmConfig
import torch
import torch.nn.functional as F
import gc
import flashinfer
from ..attn.cache import KV_Cache, StaticKV_Cache
from .glm4_layer import Glm4Layer
from .base import LLMBase
from .model_utils import layer_norm, capture_graph

from tqdm import tqdm

def rotate_half(x):
    """Pairwise rotation for the interleaved rotary-embedding layout.

    Maps each adjacent pair ``(x0, x1)`` along the last dimension to
    ``(-x1, x0)``, i.e. ``[a, b, c, d] -> [-b, a, -d, c]``.
    """
    even = x[..., 0::2]
    odd = x[..., 1::2]
    rotated = torch.stack((-odd, even), dim=-1)
    return rotated.flatten(-2)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=2):
    """Apply GLM-style (interleaved, possibly partial) rotary embeddings.

    Args:
        q (`torch.Tensor`): Query tensor, laid out `[batch, seq, heads, head_dim]`.
        k (`torch.Tensor`): Key tensor with the same layout.
        cos (`torch.Tensor`): Precomputed cosine cache, indexed by position.
        sin (`torch.Tensor`): Precomputed sine cache, indexed by position.
        position_ids (`torch.Tensor`): Positions used to gather rows from the
            `cos`/`sin` caches (this argument IS used here, unlike the upstream
            HF helper it was adapted from).
        unsqueeze_dim (`int`, defaults to 2): Axis to unsqueeze the gathered
            caches on so they broadcast over the heads dimension; 2 matches the
            `[batch, seq, heads, head_dim]` layout used by this model.
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    # Gather per-position rows, then add a broadcast axis for the heads dim.
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    # Interleaved layout: duplicate each entry of the first half pairwise
    # instead of the usual concatenated-halves layout.
    half = cos_sel.shape[-1] // 2
    cos_sel = cos_sel[..., :half].repeat_interleave(2, dim=-1)
    sin_sel = sin_sel[..., :half].repeat_interleave(2, dim=-1)

    # Partial rotary: only the leading `rotary_dim` channels are rotated,
    # the trailing channels pass through unchanged.
    rotary_dim = cos_sel.shape[-1]

    def _rotate(t):
        rot, passthrough = t[..., :rotary_dim], t[..., rotary_dim:]
        rotated = (rot * cos_sel) + (rotate_half(rot) * sin_sel)
        return torch.cat([rotated, passthrough], dim=-1)

    return _rotate(q), _rotate(k)




class Glm4(LLMBase):
    """Standalone GLM-4 decoder for inference.

    Loads weights from a Hugging Face ``GlmForCausalLM`` checkpoint into flat
    per-layer buffers (:class:`Glm4Layer`) and runs the forward pass manually
    against an external :class:`KV_Cache`.
    """

    def __init__(self,
                 model_name: str,
                 batch_size: int = 1,
                 max_length: int = 256,
                 device: str = 'cuda:0',
                 dtype=torch.float16) -> None:
        """Read the model config and cache shape hyper-parameters.

        Weights are NOT loaded here; call :meth:`alloc` before inference.
        """
        super().__init__()
        self.batch_size = batch_size
        self.device = device
        self.dtype = dtype
        self.config = GlmConfig.from_pretrained(model_name)
        self.model_name = model_name
        self.max_length = max_length
        self.hidden_size = self.config.hidden_size
        self.num_heads = self.config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = self.config.max_position_embeddings
        self.rope_theta = self.config.rope_theta
        # Normalize to a list: configs store either one eos id or a list of them.
        self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id]

    def alloc(self, **kwargs):
        """Load HF weights onto ``self.device`` and precompute rotary caches."""
        self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size)
        hf_model = GlmForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype)
        self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device)
        if self.config.tie_word_embeddings:
            # Tied embeddings: reuse the input embedding matrix as the LM head.
            self.lm_head = self.embed_tokens
        else:
            self.lm_head = hf_model.lm_head.weight.detach().to(self.device)
        self.norm_weight = hf_model.model.norm.weight.detach().to(self.device)
        self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon
        self.inv_freq = hf_model.model.rotary_emb.inv_freq.detach().to(self.device)
        self.attention_scaling = hf_model.model.rotary_emb.attention_scaling

        # Precompute cos/sin caches for every position up to max_length,
        # mirroring the HF rotary-embedding forward (outer product of
        # inv_freq and positions, halves concatenated, then scaled).
        position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cache = emb.cos()[0]
        self.sin_cache = emb.sin()[0]
        self.cos_cache = self.cos_cache * self.attention_scaling
        self.sin_cache = self.sin_cache * self.attention_scaling
        self.cos_cache = self.cos_cache.to(self.dtype)
        self.sin_cache = self.sin_cache.to(self.dtype)

        # Move each decoder layer's weights into flat buffers, dropping the HF
        # module as we go so only one copy of the weights stays resident.
        self.layers: list[Glm4Layer] = []
        for idx, hf_layer in enumerate(hf_model.model.layers):
            layer = Glm4Layer(idx)
            layer.init_parameters(hf_layer=hf_layer)
            layer.to(self.device)
            self.layers.append(layer)
            hf_model.model.layers[idx] = None
            gc.collect()

        self.num_layers = len(self.layers)

    @torch.inference_mode()
    def layer_compute(self,
                      buffer: Glm4Layer,
                      layer_idx: int,
                      hidden_states: torch.FloatTensor,
                      position_ids: torch.LongTensor,
                      attention_mask: torch.FloatTensor,
                      storage_ids: torch.LongTensor):
        """Run one decoder layer: pre-norm attention and pre-norm MLP,
        each wrapped in a residual connection.
        """
        residual = hidden_states
        hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight)
        # layer_norm preserves the shape, so sizes are read once, post-norm.
        bsz, q_len, _ = hidden_states.size()

        # Self-attention with rotary embeddings and the external KV cache.
        query_states = F.linear(hidden_states, buffer.wq, bias=buffer.qbias)
        key_states = F.linear(hidden_states, buffer.wk, bias=buffer.kbias)
        value_states = F.linear(hidden_states, buffer.wv, bias=buffer.vbias)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, self.cos_cache, self.sin_cache, position_ids)
        hidden_states = self.kv_cache.compute_attention(
            query_states, key_states, value_states, layer_idx, storage_ids, attention_mask
        )
        hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size)
        # o_proj carries no bias in GLM-4 (only q/k/v projections do).
        hidden_states = F.linear(hidden_states, buffer.wo)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight)

        # MLP: fused gate+up projection, SiLU-gated (SwiGLU-style).
        up_states = F.linear(hidden_states, buffer.gate_up_proj)
        gate, up_states = up_states.chunk(2, dim=-1)
        up_states = up_states * F.silu(gate)
        hidden_states = F.linear(up_states, buffer.down_proj)
        hidden_states = residual + hidden_states
        return hidden_states

    @torch.inference_mode()
    def inference(self,
                  input_ids: torch.LongTensor,
                  position_ids: torch.LongTensor,
                  attention_mask: torch.FloatTensor,
                  storage_ids: torch.LongTensor):
        """Full forward pass: embed, run all layers, final RMSNorm, LM head.

        Returns:
            torch.FloatTensor: float32 logits of shape (batch, seq, vocab).
        """
        hidden_states = F.embedding(input_ids, self.embed_tokens)
        for idx in range(self.num_layers):
            hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids, attention_mask, storage_ids)
        b, s, h = hidden_states.shape

        # flashinfer.rmsnorm operates on a 2-D input, so flatten batch and seq.
        hidden_states = hidden_states.reshape(b * s, h)
        hidden_states = flashinfer.rmsnorm(hidden_states, self.norm_weight, self.norm_variance_epsilon)
        hidden_states = hidden_states.reshape(b, s, h)
        logits = F.linear(hidden_states, self.lm_head).float()
        return logits


95 changes: 95 additions & 0 deletions umbrella/models/glm4_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations
import torch
from transformers.models.glm.modeling_glm import GlmDecoderLayer

class Glm4Layer:
    """Flat weight container for one GLM-4 decoder layer.

    Holds detached tensors extracted from a Hugging Face ``GlmDecoderLayer``
    so the model's forward pass can run without the HF module tree. Supports
    device moves, zero-allocation of matching buffers, and in-place copies
    (used for layer offloading/double-buffering).
    """

    # Tensor attributes that participate in to()/copy()/alloc_space().
    # `obias` is deliberately absent: GLM-4's o_proj has no bias.
    _WEIGHT_ATTRS = (
        "wq", "wk", "wv", "wo",
        "qbias", "kbias", "vbias",
        "gate_up_proj", "down_proj",
        "input_layernorm_weight", "post_attention_layernorm_weight",
    )

    def __init__(self, layer_idx, device="cpu") -> None:
        for name in self._WEIGHT_ATTRS:
            setattr(self, name, None)
        self.obias: torch.Tensor = None

        self.input_layernorm_variance_epsilon: float = 0.0
        self.post_attention_layernorm_variance_epsilon: float = 0.0

        self.layer_idx = layer_idx
        self.device = device

    def init_parameters(self, hf_layer: GlmDecoderLayer):
        """Pull detached weight tensors out of an HF decoder layer."""
        attn = hf_layer.self_attn
        self.wq: torch.Tensor = attn.q_proj.weight.detach()
        self.wk: torch.Tensor = attn.k_proj.weight.detach()
        self.wv: torch.Tensor = attn.v_proj.weight.detach()
        self.wo: torch.Tensor = attn.o_proj.weight.detach()
        self.qbias: torch.Tensor = attn.q_proj.bias.detach()
        self.kbias: torch.Tensor = attn.k_proj.bias.detach()
        self.vbias: torch.Tensor = attn.v_proj.bias.detach()

        self.gate_up_proj: torch.Tensor = hf_layer.mlp.gate_up_proj.weight.detach()
        self.down_proj: torch.Tensor = hf_layer.mlp.down_proj.weight.detach()

        self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach()
        self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon

        self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach()
        self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon

    def to(self, device: str = 'cuda:0', non_blocking=True):
        """Move every weight tensor to `device` (async when possible)."""
        self.device = device
        for name in self._WEIGHT_ATTRS:
            setattr(self, name, getattr(self, name).to(device, non_blocking=non_blocking))

    def copy(self, layer: Glm4Layer):
        """Copy another layer's weights into this layer's buffers in place."""
        for name in self._WEIGHT_ATTRS:
            getattr(self, name).copy_(getattr(layer, name), non_blocking=True)

        # Epsilons and the index are plain Python values: assign, don't copy_.
        self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon
        self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon
        self.layer_idx = layer.layer_idx

    def alloc_space(self, layer: Glm4Layer, device):
        """Allocate zeroed buffers on `device` shaped like `layer`'s weights."""
        self.device = device
        for name in self._WEIGHT_ATTRS:
            setattr(self, name, torch.zeros_like(getattr(layer, name)).to(device))
9 changes: 7 additions & 2 deletions umbrella/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
""",

'gemma2': "{}",
'mistral': "[INST] {} [/INST]"

'mistral': "[INST] {} [/INST]",
'glm4': """<|user|>
{}<|assistant|>
"""
}

SysPrompts = {
Expand All @@ -39,6 +41,9 @@
'gemma2': "",
'gemma2-it': "",
'mistral': "",
'glm4': """[gMASK]<sop><|system|>
You are a helpful assistant
"""

}

Expand Down