add support for chatglm-6b-32k
canghongjian committed Aug 4, 2023
1 parent c7d2e14 · commit 0c1fe0a
Showing 1 changed file with 22 additions and 17 deletions.
vllm/model_executor/models/chatglm.py (22 additions, 17 deletions)
@@ -1,11 +1,11 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py

"""Inference-only ChatGLM model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""

import math
@@ -31,7 +31,6 @@
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput


# vllm
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
@@ -57,18 +56,17 @@

logger = logging.get_logger(__name__)


CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
"THUDM/chatglm2-6b",
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
]

-def build_rotary_pos(seq_len, n_elem, dtype, device, base: int = 10000):
-    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
+def build_rotary_pos(seq_len, n_elem, dtype, device, base: int = 10000, rope_ratio: int=1):
+    theta = 1.0 / (base**(
+        torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
-    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
+    seq_idx = torch.arange(seq_len, dtype=dtype, device=device) / rope_ratio

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta).float()
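For context on the change above (a standalone sketch, not part of the diff): rope_ratio divides the position index before the cos/sin cache is built, so a long-context checkpoint reuses the rotary angle range its base model was trained on. The helper name and the rope_ratio value of 16 below are illustrative assumptions, not values taken from the chatglm-6b-32k config.

import torch

def rotary_cache_sketch(seq_len, n_elem, base=10000, rope_ratio=1):
    # Per-pair rotation frequencies, mirroring build_rotary_pos above.
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    # Dividing positions by rope_ratio compresses a long sequence back into
    # the position range the base checkpoint saw during training.
    seq_idx = torch.arange(seq_len).float() / rope_ratio
    idx_theta = torch.outer(seq_idx, theta)
    return torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

# With an illustrative rope_ratio of 16, token position 32752 gets the same
# rotary angles as position 2047 in an unscaled cache (32752 / 16 == 2047).
scaled = rotary_cache_sketch(32768, 64, rope_ratio=16)
plain = rotary_cache_sketch(2048, 64, rope_ratio=1)
assert torch.allclose(scaled[32752], plain[2047])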
@@ -80,8 +78,10 @@ def build_rotary_pos(seq_len, n_elem, dtype, device, base: int = 10000):
    cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
    return cache


@torch.jit.script
-def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+def apply_rotary_pos_emb(x: torch.Tensor,
+                         rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, np, hn]
    x = x.unsqueeze(1)
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
@@ -93,14 +93,17 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
-            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
-            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
+            xshaped[..., 0] * rope_cache[..., 0] -
+            xshaped[..., 1] * rope_cache[..., 1],
+            xshaped[..., 1] * rope_cache[..., 0] +
+            xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)


class MLP(torch.nn.Module):
"""MLP.
@@ -115,12 +118,13 @@ def __init__(self, config):
        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
-        self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
-                                                  config.ffn_hidden_size * 2,
-                                                  bias=self.add_bias,
-                                                  gather_output=False,
-                                                  perform_initialization=False,
-                                                  params_dtype=config.torch_dtype)
+        self.dense_h_to_4h = ColumnParallelLinear(
+            config.hidden_size,
+            config.ffn_hidden_size * 2,
+            bias=self.add_bias,
+            gather_output=False,
+            perform_initialization=False,
+            params_dtype=config.torch_dtype)


def swiglu(x):
@@ -200,6 +204,7 @@ def __init__(self, config, layer_number):

        self.norm_factor = self.hidden_size_per_attention_head ** -0.5

+        self.rope_ratio = 1 if ('rope_ratio' not in config.to_dict()) else config.rope_ratio
        self.multi_query_attention = config.multi_query_attention
        self.seq_length = config.seq_length
        self.rotary_dim = (
@@ -286,7 +291,7 @@ def forward(
        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        k_cache, v_cache = kv_cache
-        rotary_pos = build_rotary_pos(self.seq_length, self.rotary_dim // 2, dtype=query_layer.dtype, device=query_layer.device)
+        rotary_pos = build_rotary_pos(self.seq_length, self.rotary_dim // 2, dtype=query_layer.dtype, device=query_layer.device, rope_ratio=self.rope_ratio)
        if positions is not None:
            rotary_pos = rotary_pos[positions]
        else:
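A quick note on how the fallback above behaves (a sketch under assumptions, not vLLM's public API): a config with no rope_ratio field keeps rope_ratio = 1 and the original cache, while a 32K config that defines rope_ratio stretches the positions accordingly. FakeConfig and the value 16 below are hypothetical stand-ins, not values read from the real checkpoint.

class FakeConfig:
    """Hypothetical stand-in for the HuggingFace ChatGLM config object."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def to_dict(self):
        return dict(self.__dict__)

base_cfg = FakeConfig(seq_length=8192)                  # no rope_ratio field
long_cfg = FakeConfig(seq_length=32768, rope_ratio=16)  # illustrative 32K config

# Same fallback expression as the new line in __init__ above.
for cfg in (base_cfg, long_cfg):
    rope_ratio = 1 if ('rope_ratio' not in cfg.to_dict()) else cfg.rope_ratio
    print(cfg.seq_length, rope_ratio)  # prints "8192 1" then "32768 16"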
