[SLM] Add support for InternLM architecture #1835

Merged
merged 18 commits on Feb 28, 2024
Add files via upload
tlopex authored Feb 25, 2024
commit 6f9789d2884f8fb63b3f0bfecf80208b72f65fdb
102 changes: 102 additions & 0 deletions python/mlc_chat/model/internlm/internlm_loader.py
@@ -0,0 +1,102 @@
"""
This file specifies how MLC's StableLM parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_chat.loader import ExternMapping
from mlc_chat.quantization import Quantization

from .internlm_model import InternLMConfig, InternLMForCausalLM


def huggingface(model_config: InternLMConfig, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : InternLMConfig
        The configuration of the InternLM model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = InternLMForCausalLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
        # Add QKV in self attention
        attn = f"model.layers.{i}.self_attn"
        mlc_name = f"{attn}.wqkv_pack.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{attn}.q_proj.weight",
                f"{attn}.k_proj.weight",
                f"{attn}.v_proj.weight",
            ],
            functools.partial(
                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )
        mlc_name = f"{attn}.wqkv_pack.bias"
        if mlc_name in named_parameters:
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{attn}.q_proj.bias",
                    f"{attn}.k_proj.bias",
                    f"{attn}.v_proj.bias",
                ],
                functools.partial(
                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
        # Add gates in MLP
        mlp = f"model.layers.{i}.mlp"
        mlc_name = f"{mlp}.gate_up_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{mlp}.gate_proj.weight",
                f"{mlp}.up_proj.weight",
            ],
            functools.partial(
                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

    # Map every remaining parameter one-to-one by name, converting only the dtype.
    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
    return mapping
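
A minimal sketch of what the packed-QKV mapping above does at load time (illustrative only, not part of this diff; the shapes and dtypes are made up): each HuggingFace q/k/v projection is concatenated row-wise and cast to the dtype of the corresponding MLC parameter.

import numpy as np

# Hypothetical shapes: hidden_size = 8, so q/k/v are each (8, 8) in the checkpoint.
q = np.random.rand(8, 8).astype("float32")
k = np.random.rand(8, 8).astype("float32")
v = np.random.rand(8, 8).astype("float32")

# Same transformation as the lambda registered for `wqkv_pack.weight`:
# concatenate along axis 0, then cast to the MLC parameter's dtype.
wqkv_pack = np.concatenate([q, k, v], axis=0).astype("float16")
assert wqkv_pack.shape == (24, 8)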
250 changes: 250 additions & 0 deletions python/mlc_chat/model/internlm/internlm_model.py
@@ -0,0 +1,250 @@
"""
Implementation for InternLM architecture.
TODO: add docstring
"""

import dataclasses
from typing import Any, Dict, Optional

from tvm import te, tir
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op

from mlc_chat import op as op_ext
from mlc_chat.support import logging
from mlc_chat.support.config import ConfigBase
from mlc_chat.support.style import bold

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class InternLMConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
    """Configuration of the InternLM model."""

    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rms_norm_eps: float
    intermediate_size: int
    bias: bool
    use_cache: bool
    pad_token_id: int
    bos_token_id: int
    eos_token_id: int
    context_window_size: int = 0
    prefill_chunk_size: int = 0
    tensor_parallel_shards: int = 1
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if self.context_window_size == 0:
            for name in ["max_position_embeddings", "max_sequence_length"]:
                if name in self.kwargs:
                    self.context_window_size = self.kwargs.pop(name)
                    logger.info(
                        "%s not found in config.json. Falling back to %s (%d)",
                        bold("context_window_size"),
                        bold(name),
                        self.context_window_size,
                    )
                    break
            else:
                raise ValueError(
                    "Unable to determine the maximum sequence length, because none of "
                    "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                    "provided in `config.json`."
                )
        if self.prefill_chunk_size == 0:
            logger.info(
                "%s defaults to %s (%d)",
                bold("prefill_chunk_size"),
                bold("context_window_size"),
                self.context_window_size,
            )
            self.prefill_chunk_size = self.context_window_size
        elif self.prefill_chunk_size > self.context_window_size:
            logger.info(
                "Overriding %s from %d to %d (%s)",
                bold("prefill_chunk_size"),
                self.prefill_chunk_size,
                self.context_window_size,
                bold("context_window_size"),
            )
            self.prefill_chunk_size = self.context_window_size
        assert self.tensor_parallel_shards == 1, "InternLM currently does not support sharding."


# pylint: disable=invalid-name,missing-docstring


class InternLMAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
    def __init__(self, config: InternLMConfig):
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.context_window_size

        self.wqkv_pack = nn.Linear(
            self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.bias
        )
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)

        self.k_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim])
        self.v_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim])

    def forward(  # pylint: disable=too-many-locals
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,
        total_seq_len: tir.Var,
    ):
        d, h, t = self.head_dim, self.num_heads, total_seq_len
        b, s, _ = hidden_states.shape
        assert b == 1, "Only support batch size 1 at this moment."
        # Step 1. QKV Projection
        qkv = self.wqkv_pack(hidden_states)
        qkv = op.reshape(qkv, (b, s, 3 * h, d))
        # Step 2. Apply QK rotary embedding
        q, k, v = op_ext.llama_rope(qkv, t, 10000, h, h)
        # Step 3. Query and update KVCache
        self.k_cache.append(op.squeeze(k, axis=0))
        self.v_cache.append(op.squeeze(v, axis=0))
        k = self.k_cache.view(t)
        v = self.v_cache.view(t)
        # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V
        output = op_ext.attention(q, k, v, casual_mask=attention_mask)
        # Step 5. Apply output projection
        return self.o_proj(output)


class InternLMMLP(nn.Module):
    def __init__(self, config: InternLMConfig):
        self.gate_up_proj = nn.Linear(
            in_features=config.hidden_size,
            out_features=2 * config.intermediate_size,
            bias=False,
        )
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        # Fused gate/up projection followed by SiLU gating: silu(gate(x)) * up(x)
        concat_x1_x2 = self.gate_up_proj(x)
        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)
        return self.down_proj(op.silu(x1) * x2)


class InternLMDecoderLayer(nn.Module):
    def __init__(self, config: InternLMConfig):
        self.self_attn = InternLMAttention(config)
        self.mlp = InternLMMLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, -1, config.rms_norm_eps, bias=False
        )

    def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var):
        # Pre-norm residual blocks: attention first, then MLP
        out = self.self_attn(self.input_layernorm(hidden_states), attention_mask, total_seq_len)
        hidden_states = out + hidden_states
        out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = out + hidden_states
        return hidden_states


class InternLMModel(nn.Module):
    def __init__(self, config: InternLMConfig):
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)

    def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor):
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, total_seq_len)
        hidden_states = self.norm(hidden_states)
        return hidden_states


class InternLMForCausalLM(nn.Module):
    def __init__(self, config: InternLMConfig):
        self.model = InternLMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.vocab_size = config.vocab_size
        self.dtype = "float32"

    def to(self, dtype: Optional[str] = None):
        super().to(dtype=dtype)
        if dtype is not None:
            self.dtype = dtype

    def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor):
        def _index(x: te.Tensor):  # x[:, -1, :]: keep only the last token's hidden state
            b, s, d = x.shape
            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index")

        hidden_states = self.model(inputs, total_seq_len, attention_mask)
        hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
        logits = self.lm_head(hidden_states)
        if logits.dtype != "float32":
            logits = logits.astype("float32")
        return logits

    def prefill(self, inputs: Tensor, total_seq_len: tir.Var):
        def _attention_mask(batch_size, seq_len, total_seq_len):
            return te.compute(
                (batch_size, 1, seq_len, total_seq_len),
                lambda b, _, i, j: tir.if_then_else(
                    i < j - (total_seq_len - seq_len),
                    tir.min_value(self.dtype),
                    tir.max_value(self.dtype),
                ),
                name="attention_mask_prefill",
            )

        batch_size, seq_len = inputs.shape
        attention_mask = op.tensor_expr_op(
            _attention_mask,
            name_hint="attention_mask_prefill",
            args=[batch_size, seq_len, total_seq_len],
        )
        return self.forward(inputs, total_seq_len, attention_mask)
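    # Illustration (not part of the diff; values are hypothetical): with seq_len = 3 and
    # total_seq_len = 5, the prefill mask above is min_value exactly where i < j - 2, i.e.
    # query position i may attend only to key positions j <= i + 2:
    #
    #     [[max, max, max, min, min],
    #      [max, max, max, max, min],
    #      [max, max, max, max, max]]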

    def decode(self, inputs: Tensor, total_seq_len: tir.Var):
        batch_size, seq_len = inputs.shape
        attention_mask = op.full(
            shape=[batch_size, 1, seq_len, total_seq_len],
            fill_value=tir.max_value(self.dtype),
            dtype=self.dtype,
        )
        return self.forward(inputs, total_seq_len, attention_mask)

    def softmax_with_temperature(self, logits: Tensor, temperature: Tensor):
        return op.softmax(logits / temperature, axis=-1)

    def get_default_spec(self):
        batch_size = 1
        mod_spec = {
            "prefill": {
                "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"),
                "total_seq_len": int,
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "packed",
                },
            },
            "decode": {
                "inputs": nn.spec.Tensor([batch_size, 1], "int32"),
                "total_seq_len": int,
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "packed",
                },
            },
            "softmax_with_temperature": {
                "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"),
                "temperature": nn.spec.Tensor([], "float32"),
                "$": {
                    "param_mode": "none",
                    "effect_mode": "none",
                },
            },
        }
        return nn.spec.ModuleSpec.from_raw(mod_spec, self)
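
As a rough usage sketch (not part of this diff; the configuration values below are placeholders), the spec returned by `get_default_spec()` is what `internlm_loader.py` passes to `export_tvm` to enumerate the MLC parameter names:

from mlc_chat.model.internlm.internlm_model import InternLMConfig, InternLMForCausalLM

# Placeholder configuration values, chosen only to make the example self-contained;
# __post_init__ resolves context_window_size from max_position_embeddings here.
config = InternLMConfig(
    vocab_size=103168,
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    rms_norm_eps=1e-6,
    intermediate_size=11008,
    bias=True,
    use_cache=True,
    pad_token_id=2,
    bos_token_id=1,
    eos_token_id=2,
    kwargs={"max_position_embeddings": 2048},
)

model = InternLMForCausalLM(config)
model.to("float16")  # cast parameters, as the loader does when quantization is given

# The same export call appears in internlm_loader.py.
_, named_params, _ = model.export_tvm(
    spec=model.get_default_spec(),
    allow_extern=True,
)
for name, param in dict(named_params).items():
    print(name, param.dtype)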