Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions umbrella/attn/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def compute_attention(self,
layer_idx,
storage_ids :torch.Tensor = None,
attention_mask :torch.Tensor = None,
logits_soft_cap = 0):
logits_soft_cap = 0,
sm_scale=None):

key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids)

Expand All @@ -83,7 +84,8 @@ def compute_attention(self,
kv_layout="NHD",
custom_mask=attention_mask[:,:self.kv_offset],
allow_fp16_qk_reduction=True,
logits_soft_cap = logits_soft_cap
logits_soft_cap = logits_soft_cap,
sm_scale=sm_scale,
)

else:
Expand Down
6 changes: 5 additions & 1 deletion umbrella/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph, QwenFBGEMM, QwenFBGEMMOffload
from .gemma import Gemma2
from .mistral import Mistral, MistralAwqOffload, MistralOffload, MistralCudagraph, MistralAwq
from .granite import Granite
class AutoModelLM:
"""
自动模型加载器,根据模型类型动态加载对应的类。
Expand Down Expand Up @@ -140,7 +141,10 @@ class AutoModelLM:
"mistralai/Mistral-Small-24B-Instruct-2501": Mistral,
"stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwq,
"PyrTools/Ministral-8B-Instruct-2410-AWQ": MistralAwq,
"mistralai/Ministral-8B-Instruct-2410": Mistral
"mistralai/Ministral-8B-Instruct-2410": Mistral,
"ibm-granite/granite-3.2-8b-instruct-preview": Granite,
"ibm-granite/granite-3.1-8b-instruct": Granite,

}

_CUDAGRAPH_MODEL_MAPPING = {
Expand Down
180 changes: 180 additions & 0 deletions umbrella/models/granite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from transformers import GraniteForCausalLM, GraniteConfig, AutoModelForCausalLM
import torch
import torch.nn.functional as F
import gc
import flashinfer
from ..attn.cache import KV_Cache, StaticKV_Cache
from .granite_layer import GraniteLayer
from .base import LLMBase
from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph
from tqdm import tqdm

class Granite(LLMBase):
    """Inference engine for IBM Granite dense causal LMs.

    Loads Hugging Face Granite weights into flat per-layer tensors
    (``GraniteLayer``) and runs a hand-rolled forward pass backed by the
    project's ``KV_Cache``.  Granite differs from Llama-style models by four
    scalar multipliers read from the config and applied during the forward
    pass: ``embedding_multiplier`` (token embeddings), ``attention_multiplier``
    (softmax scale), ``residual_multiplier`` (both residual branches) and
    ``logits_scaling`` (final logits divisor).
    """

    def __init__(self,
                 model_name: str,
                 batch_size: int = 1,
                 max_length: int = 256,
                 device: str = 'cuda:0',
                 dtype=torch.float16) -> None:
        """Read the model config and record dimensions; no weights are loaded
        until :meth:`alloc` is called.

        :param model_name: HF hub id or local path of a Granite checkpoint.
        :param batch_size: fixed batch size the KV cache is allocated for.
        :param max_length: maximum sequence length (KV cache / RoPE cache size).
        :param device: target device for weights and activations.
        :param dtype: compute dtype for weights and activations.
        """
        super().__init__()
        self.batch_size = batch_size
        self.device = device
        self.dtype = dtype
        self.config = GraniteConfig.from_pretrained(model_name)
        self.model_name = model_name
        self.max_length = max_length

        self.hidden_size = self.config.hidden_size
        self.num_heads = self.config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = self.config.max_position_embeddings
        self.rope_theta = self.config.rope_theta
        # eos_token_id may be a single id or a list; normalize to a list.
        self.eos_tokens = self.config.eos_token_id if isinstance(self.config.eos_token_id, list) else [self.config.eos_token_id]

        # Granite-specific scalar multipliers (see GraniteConfig).
        self.residual_multiplier = self.config.residual_multiplier
        self.embedding_multiplier = self.config.embedding_multiplier
        self.attention_multiplier = self.config.attention_multiplier
        self.logits_scaling = self.config.logits_scaling

    def alloc(self, **kwargs):
        """Load HF weights onto ``self.device`` and allocate the KV cache.

        Copies the embedding table, LM head, final norm and every decoder
        layer into flat tensors, precomputes the RoPE cos/sin caches for all
        ``max_length`` positions, and frees each HF layer as it is consumed.
        """
        self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size)

        hf_model = GraniteForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype)

        # Embedding table and language-model head as plain tensors.
        self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device)
        self.lm_head = hf_model.lm_head.weight.detach().to(self.device)

        # Precompute rotary cos/sin caches for every position up front.
        position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device)
        rotary_emb = hf_model.model.rotary_emb
        inv_freq = rotary_emb.inv_freq.detach().to(self.device)

        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cache = emb.cos()[0]
        self.sin_cache = emb.sin()[0]

        # Final RMSNorm parameters.
        self.norm_weight = hf_model.model.norm.weight.detach().to(self.device)
        self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon

        # Move each decoder layer into a flat GraniteLayer, dropping the HF
        # module as we go to keep peak host memory down.
        self.layers: list[GraniteLayer] = []
        for idx, hf_layer in enumerate(hf_model.model.layers):
            layer = GraniteLayer(idx)
            layer.init_parameters(hf_layer)
            layer.to(self.device)
            self.layers.append(layer)
            hf_model.model.layers[idx] = None
            gc.collect()

        self.num_layers = len(self.layers)

    @torch.inference_mode()
    def layer_compute(self,
                      buffer: GraniteLayer,
                      layer_idx: int,
                      hidden_states: torch.FloatTensor,
                      position_ids: torch.LongTensor,
                      attention_mask: torch.FloatTensor,
                      storage_ids: torch.LongTensor):
        """Run one Granite decoder layer (attention + MLP) in place.

        :param buffer: flat weights for this layer.
        :param layer_idx: index into the KV cache.
        :param hidden_states: (batch, seq, hidden) activations.
        :param position_ids: positions used to index the RoPE caches.
        :param attention_mask: custom mask forwarded to the KV cache kernel.
        :param storage_ids: KV-cache slot indices for the new tokens.
        :returns: updated (batch, seq, hidden) activations.
        """
        residual = hidden_states
        # Single size unpack (the original computed this twice).
        bsz, q_len, _ = hidden_states.size()

        hidden_states = layer_norm(hidden_states,
                                   buffer.input_layernorm_variance_epsilon,
                                   buffer.input_layernorm_weight
                                   )

        query_states = F.linear(hidden_states, buffer.wq)
        key_states = F.linear(hidden_states, buffer.wk)
        value_states = F.linear(hidden_states, buffer.wv)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)

        query_states, key_states = apply_rotary_pos_emb(
            query_states,
            key_states,
            self.cos_cache,
            self.sin_cache,
            position_ids
        )

        # Granite replaces the default 1/sqrt(head_dim) softmax scale with
        # config.attention_multiplier, passed through as sm_scale.
        hidden_states = self.kv_cache.compute_attention(
            query_states.to(value_states.dtype),
            key_states.to(value_states.dtype),
            value_states,
            layer_idx,
            storage_ids,
            attention_mask,
            sm_scale=self.attention_multiplier
        )

        hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size)
        hidden_states = F.linear(hidden_states, buffer.wo)
        # Residual branches are scaled by residual_multiplier (Granite-specific).
        hidden_states = residual + hidden_states * self.residual_multiplier
        residual = hidden_states

        hidden_states = layer_norm(hidden_states,
                                   buffer.post_attention_layernorm_variance_epsilon,
                                   buffer.post_attention_layernorm_weight
                                   )

        # SwiGLU MLP: silu(gate) * up, then down-projection.
        up = F.linear(hidden_states, buffer.up_proj)
        gate = F.linear(hidden_states, buffer.gate_proj)
        gate = F.silu(gate)

        hidden_states = gate * up
        hidden_states = F.linear(hidden_states, buffer.down_proj)

        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states

    @torch.inference_mode()
    def inference(self,
                  input_ids: torch.LongTensor,
                  position_ids: torch.LongTensor,
                  attention_mask: torch.FloatTensor,
                  storage_ids: torch.LongTensor):
        """Full forward pass: embed, run all layers, norm, project to logits.

        :returns: float32 logits of shape (batch, seq, vocab), divided by
            ``logits_scaling`` per the Granite architecture.
        """
        # Granite scales token embeddings by embedding_multiplier.
        hidden_states = F.embedding(input_ids, self.embed_tokens) * self.embedding_multiplier
        for idx in range(self.num_layers):
            hidden_states = self.layer_compute(
                self.layers[idx],
                idx, hidden_states,
                position_ids,
                attention_mask,
                storage_ids
            )

        hidden_states = layer_norm(hidden_states, self.norm_variance_epsilon, self.norm_weight)
        logits = F.linear(hidden_states, self.lm_head).float() / self.logits_scaling
        return logits

    def gather_kv_incremental(self, indices: torch.LongTensor, offset: int):
        """Compact the KV cache to the accepted token `indices` past `offset`."""
        self.kv_cache.gather_kv_incremental(indices=indices, offset=offset)

    def clear(self):
        """Reset the KV cache for a fresh sequence."""
        self.kv_cache.clear()
92 changes: 92 additions & 0 deletions umbrella/models/granite_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import annotations
import torch
from transformers.models.granite.modeling_granite import GraniteDecoderLayer
from ..quantization.awq_utils import AwqLinear

class GraniteLayer:
    """Flat weight container for one Granite decoder layer.

    Holds detached attention/MLP projection matrices and the two RMSNorm
    weight vectors so the forward pass can use plain ``F.linear`` calls
    instead of the HF module tree.  Also supports pre-allocated
    double-buffering (``alloc_space`` + ``copy``) for offloaded execution.
    """

    def __init__(self, layer_idx, device="cpu") -> None:
        # Attention projection weights (filled by init_parameters/alloc_space).
        self.wq: torch.Tensor = None
        self.wk: torch.Tensor = None
        self.wv: torch.Tensor = None
        self.wo: torch.Tensor = None

        # SwiGLU MLP projection weights.
        self.gate_proj: torch.Tensor = None
        self.up_proj: torch.Tensor = None
        self.down_proj: torch.Tensor = None

        # RMSNorm parameters; epsilons default to the usual 1e-5 until
        # overwritten from the HF layer.
        self.input_layernorm_weight: torch.Tensor = None
        self.input_layernorm_variance_epsilon: float = 1e-05

        self.post_attention_layernorm_weight: torch.Tensor = None
        self.post_attention_layernorm_variance_epsilon: float = 1e-05

        self.layer_idx = layer_idx
        self.device = device

    def init_parameters(self, hf_layer: GraniteDecoderLayer):
        """Detach and adopt all weights from a HF ``GraniteDecoderLayer``."""
        self.wq: torch.Tensor = hf_layer.self_attn.q_proj.weight.detach()
        self.wk: torch.Tensor = hf_layer.self_attn.k_proj.weight.detach()
        self.wv: torch.Tensor = hf_layer.self_attn.v_proj.weight.detach()
        self.wo: torch.Tensor = hf_layer.self_attn.o_proj.weight.detach()

        self.gate_proj = hf_layer.mlp.gate_proj.weight.detach()
        self.up_proj = hf_layer.mlp.up_proj.weight.detach()
        self.down_proj = hf_layer.mlp.down_proj.weight.detach()

        # Layer norm weights and epsilon
        self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach()
        self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon

        self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach()
        self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon

    def to(self, device: str = 'cuda:0', non_blocking=True):
        """Move every held tensor to `device` (async when supported)."""
        self.device = device
        self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking)
        self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking)
        self.wq = self.wq.to(device, non_blocking=non_blocking)
        self.wk = self.wk.to(device, non_blocking=non_blocking)
        self.wv = self.wv.to(device, non_blocking=non_blocking)
        self.wo = self.wo.to(device, non_blocking=non_blocking)
        self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking)
        self.up_proj = self.up_proj.to(device, non_blocking=non_blocking)
        self.down_proj = self.down_proj.to(device, non_blocking=non_blocking)

    def copy(self, layer: GraniteLayer):
        """Copy another GraniteLayer's weights into this pre-allocated buffer.

        Note: the source is a ``GraniteLayer`` (it must expose ``wq`` etc.),
        not a HF ``GraniteDecoderLayer`` as the original annotation claimed.
        """
        self.wq.copy_(layer.wq, non_blocking=True)
        self.wk.copy_(layer.wk, non_blocking=True)
        self.wv.copy_(layer.wv, non_blocking=True)
        self.wo.copy_(layer.wo, non_blocking=True)

        self.gate_proj.copy_(layer.gate_proj, non_blocking=True)
        self.up_proj.copy_(layer.up_proj, non_blocking=True)
        self.down_proj.copy_(layer.down_proj, non_blocking=True)

        # Copy layer norm weights
        self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True)
        self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True)

        # Copy epsilon values
        self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon
        self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon

    def alloc_space(self, layer: GraniteLayer, device):
        """Allocate zeroed buffers on `device`, shaped like `layer`'s weights.

        Source must be a ``GraniteLayer`` (original annotation incorrectly
        said ``GraniteDecoderLayer``).  Allocates directly on the target
        device instead of allocating on the source device and moving.
        """
        self.device = device
        self.wq = torch.zeros_like(layer.wq, device=device)
        self.wk = torch.zeros_like(layer.wk, device=device)
        self.wv = torch.zeros_like(layer.wv, device=device)
        self.wo = torch.zeros_like(layer.wo, device=device)

        self.gate_proj = torch.zeros_like(layer.gate_proj, device=device)
        self.up_proj = torch.zeros_like(layer.up_proj, device=device)
        self.down_proj = torch.zeros_like(layer.down_proj, device=device)

        self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight, device=device)
        self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight, device=device)
9 changes: 8 additions & 1 deletion umbrella/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

'gemma2': "{}",
'mistral': "[INST] {} [/INST]",
'qwq': "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n<think>\n"
'qwq': "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n<think>\n",
'ibm-granite': """\n<|start_of_role|>user<|end_of_role|>{}<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>"""

}

Expand All @@ -39,12 +41,17 @@
""",
'gemma2': "",
'gemma2-it': "",

'mistral': """<s>[SYSTEM_PROMPT]You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.
Your knowledge base was last updated on 2023-10-01. The current date is 2025-03-07.

When you're not sure about some information, you say that you don't have the information and don't make up anything.
If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?")[/SYSTEM_PROMPT]""",
'qwq': "",

'ibm-granite': """<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.
Today's Date: March 05, 2025.
You are Granite, developed by IBM. You are a helpful AI assistant.<|end_of_text|>"""

}

Expand Down