Skip to content

Commit

Permalink
llama : add Qwen support (ggerganov#4281)
Browse files Browse the repository at this point in the history
* enable qwen to llama.cpp

* llama : do not GPU split bias tensors

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
2 people authored and mounta11n committed Dec 19, 2023
1 parent 7d6e02a commit 4a324d0
Show file tree
Hide file tree
Showing 5 changed files with 372 additions and 9 deletions.
131 changes: 130 additions & 1 deletion convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -168,6 +168,8 @@ def from_model_architecture(model_architecture):
return PersimmonModel
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return StableLMModel
if model_architecture == "QWenLMHeadModel":
return QwenModel
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -203,6 +205,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.PERSIMMON
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -832,6 +836,131 @@ def set_gguf_parameters(self):
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
self.gguf_writer.add_layer_norm_eps(1e-5)


class QwenModel(Model):
@staticmethod
def token_bytes_to_string(b):
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
byte_encoder = bytes_to_unicode()
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])

@staticmethod
def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> list[bytes]:
parts = [bytes([b]) for b in token]
while True:
min_idx = None
min_rank = None
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
rank = mergeable_ranks.get(pair[0] + pair[1])
if rank is not None and (min_rank is None or rank < min_rank):
min_idx = i
min_rank = rank
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
break
assert min_idx is not None
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
return parts

def set_vocab(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
toktypes: list[int] = []

from transformers import AutoTokenizer # type: ignore[attr-defined]
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams["vocab_size"]
assert max(tokenizer.get_vocab().values()) < vocab_size

merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[self.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
assert len(merged) == 2
merges.append(' '.join(map(self.token_bytes_to_string, merged)))

reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in vocab.items()}
added_vocab = tokenizer.special_tokens

for i in range(vocab_size):
if i not in reverse_vocab:
pad_token = f"[PAD{i}]".encode("utf-8")
tokens.append(bytearray(pad_token))
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.CONTROL)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
special_vocab.merges = merges
special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
self.gguf_writer.add_name("Qwen")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])

def write_tensors(self):
block_count = self.hparams["num_hidden_layers"]
model_kv = dict(self.get_tensors())
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
for name, data_torch in model_kv.items():
# we don't need these
if name.endswith(".rotary_emb.inv_freq"):
continue

old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)

###### CONVERSION LOGIC ######


Expand Down
20 changes: 20 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class MODEL_ARCH(IntEnum):
BERT = auto()
BLOOM = auto()
STABLELM = auto()
QWEN = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -132,6 +133,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -317,6 +319,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.QWEN: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GPT2: [
# TODO
],
Expand All @@ -336,6 +352,10 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.PERSIMMON: [
MODEL_TENSOR.ROPE_FREQS,
],
MODEL_ARCH.QWEN: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
}

#
Expand Down
18 changes: 10 additions & 8 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 gpt-j mpt refact
"transformer.wte", # gpt2 gpt-j mpt refact qwen
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
Expand Down Expand Up @@ -38,7 +38,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
"output", # llama-pth bloom
"word_embeddings_for_head", # persimmon
),
Expand All @@ -51,7 +51,7 @@ class TensorNameMap:
"norm", # llama-pth
"embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt
"ln_f", # refact bloom
"ln_f", # refact bloom qwen
"language_model.encoder.final_layernorm", # persimmon
),

Expand All @@ -65,7 +65,7 @@ class TensorNameMap:
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
Expand All @@ -85,7 +85,7 @@ class TensorNameMap:
# Attention query-key-value
MODEL_TENSOR.ATTN_QKV: (
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
"transformer.blocks.{bid}.attn.Wqkv", # mpt
"transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom
Expand Down Expand Up @@ -119,7 +119,7 @@ class TensorNameMap:
# Attention output
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 refact
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
Expand All @@ -139,7 +139,7 @@ class TensorNameMap:
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2 refact
"transformer.h.{bid}.ln_2", # gpt2 refact qwen
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
Expand All @@ -161,18 +161,20 @@ class TensorNameMap:
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
"transformer.h.{bid}.mlp.w1", # qwen
),

# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
),

# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom
Expand Down
Loading

0 comments on commit 4a324d0

Please sign in to comment.